Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-14 02:02:46 +00:00
reorganize files to prepare for MIOpen integration (#51)
* change olc cmake
* adding online compile to fwd-v4r5r2
* update scripts
* rename fwd-v4r5r2 to fwd-v6r1
* clean up
[ROCm/composable_kernel commit: 1264925422]
@@ -6,14 +6,14 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
include(TargetFlags)
include(AddKernels)

#c++
## C++
enable_language(CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")

#OpenMP
## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# workaround issue hipcc in rocm3.5 cannot find openmp
set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
@@ -35,56 +35,8 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries(${OpenMP_pthread_LIBRARY})

#GPU backend
if(DEVICE_BACKEND STREQUAL "AMD")
find_package(HIP REQUIRED)
endif()

#
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/half/include
${PROJECT_SOURCE_DIR}/driver/include
${PROJECT_BINARY_DIR}/composable_kernel/include/utility
)

if(DEVICE_BACKEND STREQUAL "AMD")
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/external/rocm/include
)
endif()

if(DEVICE_BACKEND STREQUAL "AMD")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
endif()

add_subdirectory(driver)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")

message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")

if(DEVICE_BACKEND STREQUAL "AMD")
set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
endif()

add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})

target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)

target_link_libraries(conv_driver_v2 PRIVATE modConv)
target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
target_link_libraries(conv_driver_v2_olc PRIVATE modConv)

## HIP
find_package(HIP REQUIRED)
message(STATUS "Build with HIP ${hip_VERSION}")

add_subdirectory(host)
@@ -1,292 +0,0 @@
|
||||
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
|
||||
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_contraction_v1r1.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGKGM0GM1GridDesc,
|
||||
typename BGKGN0GN1GridDesc,
|
||||
typename CGM0GM1GN0GN1GridDesc,
|
||||
index_t GM1PerBlockGM11,
|
||||
index_t GN1PerBlockGN11,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM10,
|
||||
index_t M1N1ThreadClusterN10,
|
||||
index_t M1N1ThreadClusterM11,
|
||||
index_t M1N1ThreadClusterN11,
|
||||
typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
|
||||
typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_GM11,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
|
||||
typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_GN11,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
__host__ float
|
||||
driver_dynamic_contraction_v1r1(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
|
||||
const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
index_t nrepeat)
|
||||
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
// GEMM
|
||||
using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGKGM0GM1GridDesc,
|
||||
BGKGN0GN1GridDesc,
|
||||
CGM0GM1GN0GN1GridDesc,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_GM11,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_GN11,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
|
||||
const auto K = a_gk_gm0_gm1_grid_desc.GetLength(I0);
|
||||
|
||||
if(!GridwiseContraction::CheckValidity(
|
||||
a_gk_gm0_gm1_grid_desc, b_gk_gn0_gn1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
|
||||
}
|
||||
|
||||
const auto a_gk_gm0_gm10_gm11_grid_desc =
|
||||
GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc);
|
||||
const auto b_gk_gn0_gn10_gn11_grid_desc =
|
||||
GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc);
|
||||
|
||||
using AGKGM0GM10GM11GridDesc = decltype(a_gk_gm0_gm10_gm11_grid_desc);
|
||||
using BGKGN0GN10GN11GridDesc = decltype(b_gk_gn0_gn10_gn11_grid_desc);
|
||||
|
||||
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
|
||||
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
|
||||
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
|
||||
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
|
||||
|
||||
// c_blockid_to_gm10_gn10_block_cluster_adaptor
|
||||
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
|
||||
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
|
||||
using CBlockIdToGM10GN10BlockClusterAdaptor =
|
||||
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(K);
|
||||
|
||||
const bool has_double_tail_k_block_loop =
|
||||
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(K);
|
||||
|
||||
{
|
||||
std::cout << "a_gk_gm0_gm10_gm11_grid_desc{" << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0)
|
||||
<< ", " << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I1) << ", "
|
||||
<< a_gk_gm0_gm10_gm11_grid_desc.GetLength(I2) << ", "
|
||||
<< a_gk_gm0_gm10_gm11_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_gk_gn0_gn10_gn11_grid_desc{" << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I0)
|
||||
<< ", " << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I1) << ", "
|
||||
<< b_gk_gn0_gn10_gn11_grid_desc.GetLength(I2) << ", "
|
||||
<< b_gk_gn0_gn10_gn11_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGKGM0GM10GM11GridDesc>,
|
||||
remove_reference_t<BGKGN0GN10GN11GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGKGM0GM10GM11GridDesc>,
|
||||
remove_reference_t<BGKGN0GN10GN11GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGKGM0GM10GM11GridDesc>,
|
||||
remove_reference_t<BGKGN0GN10GN11GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGKGM0GM10GM11GridDesc>,
|
||||
remove_reference_t<BGKGN0GN10GN11GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
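Editor's note: the deleted v1r1 driver above (like every driver touched by this commit) picks one of four kernel instantiations from the two runtime flags has_main_k_block_loop and has_double_tail_k_block_loop. A minimal stand-alone sketch of that dispatch pattern, using a hypothetical launch<MainLoop, DoubleTail>() helper in place of launch_and_time_kernel (names here are illustrative, not from the diff):

#include <cstdio>

// Stands in for instantiating and launching one of the four kernel variants.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void launch()
{
    std::printf("launch<%d, %d>\n", HasMainKBlockLoop, HasDoubleTailKBlockLoop);
}

// Same four-way branch as driver_dynamic_contraction_v1r1 above.
void dispatch(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        launch<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        launch<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        launch<false, true>();
    else
        launch<false, false>();
}

The flags are runtime values, so the branch is how the driver maps them onto compile-time template parameters of the kernel.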
@@ -13,19 +13,19 @@ template <index_t BlockSize,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGKGM0GM1GridDesc,
|
||||
typename BGKGN0GN1GridDesc,
|
||||
typename CGM0GM1GN0GN1GridDesc,
|
||||
typename AGridDesc_GK0_GM0_GM1_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN1_GK1,
|
||||
typename CGridDesc_GM0_GM1_GN0_GN1,
|
||||
index_t GM1PerBlockGM11,
|
||||
index_t GN1PerBlockGN11,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM10,
|
||||
index_t M1N1ThreadClusterN10,
|
||||
index_t M1N1ThreadClusterM11,
|
||||
index_t M1N1ThreadClusterN11,
|
||||
index_t GK0PerBlock,
|
||||
index_t BM1PerThreadBM11,
|
||||
index_t BN1PerThreadBN11,
|
||||
index_t BK0PerThread,
|
||||
index_t BM10BN10ThreadClusterBM100,
|
||||
index_t BM10BN10ThreadClusterBN100,
|
||||
index_t BM10BN10ThreadClusterBM101,
|
||||
index_t BM10BN10ThreadClusterBN101,
|
||||
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
@@ -52,9 +52,9 @@ __host__ float
|
||||
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
|
||||
const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
|
||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
@@ -71,79 +71,83 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
// GEMM
|
||||
using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGKGM0GM1GridDesc,
|
||||
BGKGN0GN1GridDesc,
|
||||
CGM0GM1GN0GN1GridDesc,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
using GridwiseContraction =
|
||||
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM100,
|
||||
BM10BN10ThreadClusterBN100,
|
||||
BM10BN10ThreadClusterBM101,
|
||||
BM10BN10ThreadClusterBN101,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
|
||||
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
if(!GridwiseContraction::CheckValidity(
|
||||
a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
|
||||
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
|
||||
throw std::runtime_error("wrong! "
|
||||
"GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
|
||||
"GM0_GM1_GN0_GN1 has invalid setting");
|
||||
}
|
||||
|
||||
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc =
|
||||
GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc);
|
||||
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc =
|
||||
GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc);
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
|
||||
using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc);
|
||||
using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc);
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
|
||||
|
||||
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
|
||||
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
|
||||
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
// c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
|
||||
|
||||
// c_blockid_to_gm10_gn10_block_cluster_adaptor
|
||||
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
|
||||
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
// c_grid_block_cluster_blockid_to_gm10_gn10
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
|
||||
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CBlockIdToGM10GN10BlockClusterAdaptor =
|
||||
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
|
||||
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
|
||||
|
||||
@@ -151,41 +155,41 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
|
||||
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
|
||||
|
||||
{
|
||||
std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{"
|
||||
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", "
|
||||
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", "
|
||||
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", "
|
||||
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
|
||||
std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{"
|
||||
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", "
|
||||
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", "
|
||||
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", "
|
||||
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
|
||||
std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
|
||||
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
|
||||
std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
const auto kernel = kernel_dynamic_contraction_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
|
||||
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
@@ -198,21 +202,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
const auto kernel = kernel_dynamic_contraction_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
|
||||
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
@@ -225,21 +229,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
const auto kernel = kernel_dynamic_contraction_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
|
||||
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
@@ -252,21 +256,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_dynamic_contraction_v1r1<
|
||||
const auto kernel = kernel_dynamic_contraction_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
|
||||
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
|
||||
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
|
||||
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
@@ -279,10 +283,10 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
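Editor's note: throughout the v1r2 driver above, descriptor objects built at run time are turned back into template arguments with decltype plus remove_reference_t before the kernel is instantiated. A self-contained sketch of that pattern (GridDesc and kernel_stub are illustrative stand-ins, not from the diff):

#include <cstdio>
#include <type_traits>

struct GridDesc { int length; };            // stands in for a dynamic tensor descriptor

template <typename ADesc, typename BDesc>   // stands in for kernel_dynamic_contraction_v1r2
void kernel_stub(const ADesc& a, const BDesc& b)
{
    std::printf("lengths: %d %d\n", a.length, b.length);
}

int main()
{
    const auto a_grid_desc = GridDesc{8};
    const auto b_grid_desc = GridDesc{16};

    // Capture the descriptor types, stripping any reference as the driver
    // does with remove_reference_t<...>.
    using ADesc = std::remove_reference_t<decltype(a_grid_desc)>;
    using BDesc = std::remove_reference_t<decltype(b_grid_desc)>;

    kernel_stub<ADesc, BDesc>(a_grid_desc, b_grid_desc);
    return 0;
}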
@@ -1,387 +0,0 @@
|
||||
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
|
||||
#define CK_DRIVER_DYNAMIC_GEMM_V1
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_v1r1.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM10,
|
||||
index_t M1N1ThreadClusterN10,
|
||||
index_t M1N1ThreadClusterM11,
|
||||
index_t M1N1ThreadClusterN11,
|
||||
typename ABlockTransferThreadSliceLengths_K_M,
|
||||
typename ABlockTransferThreadClusterLengths_K_M,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_M,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N,
|
||||
typename BBlockTransferThreadClusterLengths_K_N,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_N,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalIteratorHacks,
|
||||
typename BGlobalIteratorHacks,
|
||||
typename CGlobalIteratorHacks,
|
||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
||||
__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
|
||||
const FloatAB* p_b_global,
|
||||
FloatC* p_c_global,
|
||||
const AGlobalDesc& a_k_m_global_desc,
|
||||
const BGlobalDesc& b_k_n_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc,
|
||||
AGlobalIteratorHacks,
|
||||
BGlobalIteratorHacks,
|
||||
CGlobalIteratorHacks,
|
||||
AGlobalMoveSliceWindowIteratorHacks,
|
||||
BGlobalMoveSliceWindowIteratorHacks,
|
||||
index_t nrepeat)
|
||||
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto M = a_k_m_global_desc.GetLength(I1);
|
||||
const auto N = b_k_n_global_desc.GetLength(I1);
|
||||
const auto K = a_k_m_global_desc.GetLength(I0);
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
|
||||
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
|
||||
|
||||
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// GEMM
|
||||
using gridwise_gemm =
|
||||
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGlobalDesc,
|
||||
BGlobalDesc,
|
||||
CGlobalDesc,
|
||||
CBlockClusterDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_K_M,
|
||||
ABlockTransferThreadClusterLengths_K_M,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N,
|
||||
BBlockTransferThreadClusterLengths_K_N,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGlobalIteratorHacks,
|
||||
BGlobalIteratorHacks,
|
||||
CGlobalIteratorHacks,
|
||||
AGlobalMoveSliceWindowIteratorHacks,
|
||||
BGlobalMoveSliceWindowIteratorHacks>;
|
||||
|
||||
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
|
||||
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
|
||||
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
|
||||
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
|
||||
|
||||
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
|
||||
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
|
||||
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
|
||||
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
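Editor's note: a worked example (sizes are illustrative, not from the diff) of how the deleted launch_kernel_dynamic_gemm_v1r1 above derives its launch parameters:

#include <cstdio>

int main()
{
    const int M = 1024, N = 512, K = 192;
    const int MPerBlock = 128, NPerBlock = 128, KPerBlock = 16;

    // One workgroup per MPerBlock x NPerBlock tile of C.
    const int grid_size = (M / MPerBlock) * (N / NPerBlock);                         // 8 * 4 = 32

    // Same expressions as in the driver above.
    const bool has_main_k_block_loop        = (K + KPerBlock) / (2 * KPerBlock) > 1; // 208 / 32 = 6 > 1
    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;              // 12 % 2 == 0

    std::printf("grid_size=%d main=%d double_tail=%d\n",
                grid_size, has_main_k_block_loop, has_double_tail_k_block_loop);
    return 0;
}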
@@ -1,125 +0,0 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <index_t N0_,
|
||||
typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad(
|
||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gk_gm0_gm1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1, 2>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto N0 = Number<N0_>{};
|
||||
const auto N1 = N / N0;
|
||||
|
||||
const auto in_n0_n1_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
|
||||
|
||||
const auto in_gk_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n0_n1_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N0),
|
||||
make_merge_transform(make_tuple(N1, Ho, Wo))),
|
||||
make_tuple(Sequence<2, 3, 5>{}, Sequence<0>{}, Sequence<1, 4, 6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_n_k_howo_grid_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
|
||||
|
||||
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
out_n_k_howo_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
|
||||
make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_pass_through_transform(Ho * Wo)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
out_n0_n1_1_k_howo_grid_desc,
|
||||
make_tuple(make_pass_through_transform(I1),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(Number<N0>{}),
|
||||
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
|
||||
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gk_gm0_gm1_grid_desc, in_gk_gn0_gn1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
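Editor's note: the pad and embed transforms in the deleted v4r5 header above rely on the standard convolution geometry; a small worked check of the output-size relation they assume (values are illustrative, not from the diff):

#include <cstdio>

int main()
{
    const int Hi = 28, Y = 3;
    const int InLeftPadH = 1, InRightPadH = 1, ConvStrideH = 1, ConvDilationH = 1;

    const int Hip  = Hi + InLeftPadH + InRightPadH;      // padded input height: 30
    const int YEff = (Y - 1) * ConvDilationH + 1;        // dilated filter extent: 3
    const int Ho   = (Hip - YEff) / ConvStrideH + 1;     // output height: 28

    std::printf("Hip=%d Ho=%d\n", Hip, Ho);
    return 0;
}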
@@ -0,0 +1,129 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
wei_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
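Editor's note: a quick worked example (sizes are assumptions, not from the diff) of the implicit-GEMM problem sizes produced by the v4r4r2 transform above:

#include <cstdio>

int main()
{
    // GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X, as stated in the header above.
    const int N = 256, C = 64, K = 128, Y = 3, X = 3, Ho = 56, Wo = 56;
    const int GemmK1 = 4;

    const int GemmM  = K;              // 128
    const int GemmN  = N * Ho * Wo;    // 802816
    const int GemmK  = C * Y * X;      // 576
    const int GemmK0 = GemmK / GemmK1; // 144, assumes GemmK is divisible by GemmK1

    std::printf("GemmM=%d GemmN=%d GemmK=%d GemmK0=%d GemmK1=%d\n",
                GemmM, GemmN, GemmK, GemmK0, GemmK1);
    return 0;
}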
@@ -1,5 +1,5 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
@@ -17,10 +17,10 @@ template <typename... Wei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t N0Value,
|
||||
index_t C0Value>
|
||||
typename N0Type,
|
||||
typename C0Type>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
@@ -28,8 +28,8 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<N0Value>,
|
||||
Number<C0Value>)
|
||||
const N0Type& N0,
|
||||
const C0Type& C0)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -61,9 +61,6 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
constexpr auto N0 = Number<N0Value>{};
|
||||
constexpr auto C0 = Number<C0Value>{};
|
||||
|
||||
const auto N1 = N / N0;
|
||||
const auto C1 = C / C0;
|
||||
|
||||
@@ -109,7 +106,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(

const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
@@ -119,7 +116,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_pass_through_transform(N0),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"
#include "threadwise_contraction.hpp"

namespace ck {

@@ -4,43 +4,43 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_gemm_v2.hpp"
#include "threadwise_contraction.hpp"

namespace ck {

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
// 1. A:
// 1. AK0MK1BlockDesc is known at compile-time
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BK0NK1BlockDesc is known at compile-time
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
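Read together with the Run() member further down, the 2x2 ABBA schedule mentioned above issues reads and math in the following order (a paraphrase for the reader, not code from the commit):

// prologue : read A0, read B0, read B1, read A1
//            C00 += A0^T * B0, C01 += A0^T * B1   (overlapped with the reads)
// steady    : per BK0PerThread step, interleave the next slice's A/B reads with
//             C10 += A1^T * B0 and C11 += A1^T * B1, then C00/C01 of the new slice
// epilogue  : C10 += A1^T * B0, C11 += A1^T * B1 for the last slice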
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t M1PerThreadM11,
|
||||
index_t N1PerThreadN11,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM100,
|
||||
index_t M1N1ThreadClusterN100,
|
||||
index_t M1N1ThreadClusterM101,
|
||||
index_t M1N1ThreadClusterN101,
|
||||
index_t AThreadCopyScalarPerVector_M11,
|
||||
index_t BThreadCopyScalarPerVector_N11,
|
||||
typename std::enable_if<AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
typename ABlockDesc_BK0_BM_BK1,
|
||||
typename BBlockDesc_BK0_BN_BK1,
|
||||
index_t BM1PerThreadBM11,
|
||||
index_t BN1PerThreadBN11,
|
||||
index_t BK0PerThread,
|
||||
index_t BM10BN10ThreadClusterBM100,
|
||||
index_t BM10BN10ThreadClusterBN100,
|
||||
index_t BM10BN10ThreadClusterBM101,
|
||||
index_t BM10BN10ThreadClusterBN101,
|
||||
index_t AThreadCopyScalarPerVector_BM11,
|
||||
index_t BThreadCopyScalarPerVector_BN11,
|
||||
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
|
||||
struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
using BIndex = MultiIndex<3>;
|
||||
@@ -51,138 +51,144 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t M = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t N = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
|
||||
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
|
||||
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
|
||||
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
|
||||
|
||||
static constexpr index_t M100 = M1N1ThreadClusterM100;
|
||||
static constexpr index_t N100 = M1N1ThreadClusterN100;
|
||||
static constexpr index_t BM100 = BM10BN10ThreadClusterBM100;
|
||||
static constexpr index_t BN100 = BM10BN10ThreadClusterBN100;
|
||||
|
||||
static constexpr index_t M101 = M1N1ThreadClusterM101;
|
||||
static constexpr index_t N101 = M1N1ThreadClusterN101;
|
||||
static constexpr index_t BM101 = BM10BN10ThreadClusterBM101;
|
||||
static constexpr index_t BN101 = BM10BN10ThreadClusterBN101;
|
||||
|
||||
static constexpr index_t M11 = M1PerThreadM11;
|
||||
static constexpr index_t N11 = N1PerThreadN11;
|
||||
static constexpr index_t BM11 = BM1PerThreadBM11;
|
||||
static constexpr index_t BN11 = BN1PerThreadBN11;
|
||||
|
||||
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
|
||||
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
|
||||
static constexpr index_t BM1 =
|
||||
BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11;
|
||||
static constexpr index_t BN1 =
|
||||
BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11;
|
||||
|
||||
static constexpr index_t M0 = M / M1;
|
||||
static constexpr index_t N0 = N / N1;
|
||||
static constexpr index_t BM0 = BM / BM1;
|
||||
static constexpr index_t BN0 = BN / BN1;
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc)
|
||||
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
|
||||
{
|
||||
const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
a_k0_m_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
|
||||
a_block_desc_bk0_bm_bk1,
|
||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
|
||||
make_pass_through_transform(Number<BK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return a_k0_m0_m1_k1_block_desc;
|
||||
return a_block_bk0_bm0_bm1_bk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc)
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
|
||||
{
|
||||
const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
b_k0_n_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
|
||||
b_block_desc_bk0_bn_bk1,
|
||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
|
||||
make_pass_through_transform(Number<BK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return b_k0_n0_n1_k1_block_desc;
|
||||
return b_block_desc_bk0_bn0_bn1_bk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
|
||||
{
|
||||
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
// lower: [M, N]
|
||||
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// lower: [BM, BN]
|
||||
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
|
||||
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
|
||||
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
|
||||
|
||||
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
|
||||
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
|
||||
{
|
||||
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
// lower: [M0, M1, N0, N1]
|
||||
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// lower: [BM0, BM1, BN0, BN1]
|
||||
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_pass_through_transform(Number<M0>{}),
|
||||
make_tuple(make_pass_through_transform(Number<BM0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
|
||||
make_pass_through_transform(Number<N0>{}),
|
||||
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
|
||||
make_pass_through_transform(Number<BN0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
|
||||
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
|
||||
|
||||
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
|
||||
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
|
||||
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
|
||||
{
|
||||
return Sequence<M0, M11, N0, N11>{};
|
||||
return Sequence<BM0, BM11, BN0, BN11>{};
|
||||
}
|
||||
|
||||
static constexpr auto a_k0_m0_m1_k1_block_desc_ =
|
||||
MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{});
|
||||
static constexpr auto b_k0_n0_n1_k1_block_desc_ =
|
||||
MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{});
|
||||
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
|
||||
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
|
||||
|
||||
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2()
|
||||
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
|
||||
__device__ BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
|
||||
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id())},
|
||||
a_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
|
||||
b_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
|
||||
{
|
||||
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(BlockSize == M101 * M100 * N101 * N100,
|
||||
static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
|
||||
"wrong! blocksize and cluster size not consistent");
|
||||
|
||||
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
|
||||
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
|
||||
|
||||
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
|
||||
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
|
||||
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(M0 == 2 && N0 == 2, "wrong");
|
||||
static_assert(BM0 == 2 && BN0 == 2, "wrong");
|
||||
}
|
||||
|
||||
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
|
||||
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
|
||||
{
|
||||
// lower: [M0, M1, N0, N1]
|
||||
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
|
||||
// lower: [BM0, BM1, BN0, BN1]
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
constexpr auto adaptor0 =
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
|
||||
|
||||
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
// upper: [Tid, M0, M11, N0, N11]
|
||||
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// upper: [Tid, BM0, BM11, BN0, BN11]
|
||||
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
|
||||
make_pass_through_transform(M0),
|
||||
make_pass_through_transform(M11),
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(N11)),
|
||||
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
|
||||
make_pass_through_transform(BM0),
|
||||
make_pass_through_transform(BM11),
|
||||
make_pass_through_transform(BN0),
|
||||
make_pass_through_transform(BN11)),
|
||||
make_tuple(
|
||||
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
@@ -192,201 +198,203 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
|
||||
return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
|
||||
}
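Informally, the mapping above unmerges the flat thread id into the (BM100, BN100, BM101, BN101) cluster coordinates and scales them by the per-thread lengths to obtain the thread's (BM0, BM1, BN0, BN1) origin. With assumed sizes BM100 = BN100 = 2, BM101 = BN101 = 4 and BM11 = BN11 = 4 (illustrative only): thread 0 originates at (0, 0, 0, 0), thread 1 at (0, 0, 0, 4), and thread 4 at (0, 4, 0, 0).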
|
||||
|
||||
template <typename CM0M1N0N1ThreadDesc,
|
||||
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
|
||||
typename ABlockBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
|
||||
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11& c_m0_m1_n0_n1_thread_desc,
|
||||
const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
|
||||
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
|
||||
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
|
||||
static_assert(BM0 == 2 && BN0 == 2 &&
|
||||
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
|
||||
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
|
||||
"wrong");
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
|
||||
a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize());
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
|
||||
b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize());
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto threadwise_gemm =
|
||||
ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1<FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_k0_m0_m1_k1_thread_desc_),
|
||||
decltype(b_k0_n0_n1_k1_thread_desc_),
|
||||
CM0M1N0N1ThreadDesc,
|
||||
Sequence<KPerThread, K1>,
|
||||
Sequence<1, M1PerThreadM11>,
|
||||
Sequence<1, N1PerThreadN11>>{};
|
||||
constexpr auto threadwise_contraction =
|
||||
ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
|
||||
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
|
||||
CThreadDesc_BM0_BM11_BN0_BN11,
|
||||
Sequence<BK0PerThread, BK1>,
|
||||
Sequence<1, BM1PerThreadBM11>,
|
||||
Sequence<1, BN1PerThreadBN11>>{};
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_k0_m0_m1_k1_thread_desc_,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_k0_n0_n1_k1_thread_desc_,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_k0_n0_n1_k1_thread_desc_,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_k0_m0_m1_k1_thread_desc_,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
|
||||
// loop over rest of k
|
||||
static_for<KPerThread, K0, KPerThread>{}([&](auto k) {
|
||||
// loop over rest of bk0
|
||||
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
|
||||
make_tuple(k, I0, I0, I0),
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(bk0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_k0_m0_m1_k1_thread_desc_,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
|
||||
make_tuple(k, I0, I0, I0),
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(bk0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_k0_n0_n1_k1_thread_desc_,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
|
||||
make_tuple(k, I1, I0, I0),
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(bk0, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_k0_n0_n1_k1_thread_desc_,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
|
||||
make_tuple(k, I1, I0, I0),
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(bk0, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_k0_m0_m1_k1_thread_desc_,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
});
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
}
|
||||
|
||||
private:
|
||||
// A[K0, M0, M1, K1]
|
||||
static constexpr auto a_k0_m0_m1_k1_thread_desc_ =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}, Number<K1>{}));
|
||||
// A[BK0, BM0, BM1, BK1]
|
||||
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
||||
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
|
||||
|
||||
// B[K0, N0, N1, K1]
|
||||
static constexpr auto b_k0_n0_n1_k1_thread_desc_ =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}, Number<K1>{}));
|
||||
// B[BK0, BN0, BN1, BK1]
|
||||
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
||||
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
|
||||
FloatA,
|
||||
FloatA,
|
||||
decltype(a_k0_m0_m1_k1_block_desc_),
|
||||
decltype(a_k0_m0_m1_k1_thread_desc_),
|
||||
Sequence<KPerThread, 1, M1PerThreadM11, K1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, M1PerThreadM11, K1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
|
||||
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
|
||||
Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
|
||||
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
|
||||
FloatB,
|
||||
FloatB,
|
||||
decltype(b_k0_n0_n1_k1_block_desc_),
|
||||
decltype(b_k0_n0_n1_k1_thread_desc_),
|
||||
Sequence<KPerThread, 1, N1PerThreadN11, K1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, N1PerThreadN11, K1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
|
||||
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
|
||||
Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
|
||||
CIndex c_thread_origin_data_idx_;
|
||||
|
||||
|
||||
@@ -1,681 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP
|
||||
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_v2r2.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseContraction,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AGKGM0GM10GM11GridDesc,
|
||||
typename BGKGN0GN10GN11GridDesc,
|
||||
typename CGM10BM0BM1GN10BN0BN1GridDesc,
|
||||
typename CBlockIdToGM10GN10BlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_contraction_v1r1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGKGM0GM10GM11GridDesc a_gk_gm0_gm10_gm11_grid_desc,
|
||||
const BGKGN0GN10GN11GridDesc b_gk_gn0_gn10_gn11_grid_desc,
|
||||
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
const CBlockIdToGM10GN10BlockClusterAdaptor
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGKGM0GM1GridDesc,
|
||||
typename BGKGN0GN1GridDesc,
|
||||
typename CGM0GM1GN0GN1GridDesc,
|
||||
index_t GM1PerBlockGM11,
|
||||
index_t GN1PerBlockGN11,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThreadM111,
|
||||
index_t N1PerThreadN111,
|
||||
index_t KPerThread,
|
||||
index_t M11N11ThreadClusterM1100,
|
||||
index_t M11N11ThreadClusterN1100,
|
||||
index_t M11N11ThreadClusterM1101,
|
||||
index_t M11N11ThreadClusterN1101,
|
||||
typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
|
||||
typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_GM11,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
|
||||
typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_GN11,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// GM0 and GN0 need to be known at compile-time
|
||||
static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
|
||||
static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
|
||||
Number<BBlockTransferDstScalarPerVector_GN11>{},
|
||||
Number<M1PerThreadM111>{},
|
||||
Number<N1PerThreadN111>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gk_gm0_gm10_gm11_block_desc =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gk_gn0_gn10_gn11_block_desc =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
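A worked instance of the size computed above (illustrative numbers; assuming the element counts are already multiples of max_lds_align so the padding terms vanish):

// KPerBlock = 8, GM0 = GN0 = 2, GM1PerBlockGM11 = GN1PerBlockGN11 = 64, FloatAB = float:
//   A block: 8 * 2 * 1 * 64 = 1024 elements, B block: 1024 elements
//   LDS = 2 * (1024 + 1024) * sizeof(float) = 16384 bytes
// the factor of 2 is the even/odd double buffer used in Run() below.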
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
|
||||
const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
|
||||
"wrong! GM0 and GN0 need to be known at compile-time");
|
||||
|
||||
const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
|
||||
const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
|
||||
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
|
||||
GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
|
||||
GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
|
||||
GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
|
||||
GM0 == a_gk_gm0_gm1_grid_desc.GetLength(I1) &&
|
||||
GM1 == a_gk_gm0_gm1_grid_desc.GetLength(I2) &&
|
||||
GN0 == b_gk_gn0_gn1_grid_desc.GetLength(I1) &&
|
||||
GN1 == b_gk_gn0_gn1_grid_desc.GetLength(I2) &&
|
||||
GK == b_gk_gn0_gn1_grid_desc.GetLength(I0)) &&
|
||||
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK % KPerBlock == 0));
|
||||
}
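CheckValidity is meant to be evaluated on the host before the kernel is launched; a minimal sketch of a caller (the surrounding driver code and descriptor names are assumptions, not part of this commit):

// if(!GridwiseContraction::CheckValidity(a_desc, b_desc, c_desc))
// {
//     // unsupported shape for this tile configuration: fall back or report an error
// }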
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
{
|
||||
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
|
||||
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t GM11 = GM1PerBlockGM11;
|
||||
constexpr index_t GN11 = GN1PerBlockGN11;
|
||||
|
||||
const index_t GM10 = GM1 / GM11;
|
||||
const index_t GN10 = GN1 / GN11;
|
||||
|
||||
const index_t grid_size = GM10 * GN10;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK)
|
||||
{
|
||||
const bool has_main_k_block_loop = (GK + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (GK / KPerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
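Worked through for one assumed shape: with GK = 64 and KPerBlock = 8, (64 + 8) / (2 * 8) = 4 > 1, so the main double-buffered loop runs, and (64 / 8) % 2 == 0, so the tail handles two K blocks; the do-while in Run() below then executes three times (six K blocks) before the double tail consumes the last two.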
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAGKGM0GM10GM11GridDescriptor(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc)
|
||||
{
|
||||
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
|
||||
const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
|
||||
|
||||
const auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
const auto GM10 = GM1 / GM11;
|
||||
|
||||
const auto a_gk_gm0_gm10_gm11_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
a_gk_gm0_gm1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GK),
|
||||
make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return a_gk_gm0_gm10_gm11_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBGKGN0GN10GN11GridDescriptor(const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc)
|
||||
{
|
||||
const auto GK = b_gk_gn0_gn1_grid_desc.GetLength(I0);
|
||||
const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
|
||||
|
||||
const auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto b_gk_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
b_gk_gn0_gn1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GK),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return b_gk_gn0_gn10_gn11_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
{
|
||||
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
|
||||
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
constexpr auto BM = GM0 * GM11;
|
||||
constexpr auto BN = GN0 * GN11;
|
||||
|
||||
constexpr auto BM1 =
|
||||
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
|
||||
constexpr auto BN1 =
|
||||
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
|
||||
|
||||
constexpr auto BM0 = BM / BM1;
|
||||
constexpr auto BN0 = BN / BN1;
|
||||
|
||||
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
c_gm0_gm1_gn0_gn1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_merge_transform(make_tuple(GM0, GM11)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_merge_transform(make_tuple(GN0, GN11))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
c_gm10_bm_gn10_bn_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_unmerge_transform(make_tuple(BM0, BM1)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_unmerge_transform(make_tuple(BN0, BN1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
{
|
||||
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
|
||||
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return c_blockid_to_gm10_gn10_block_cluster_adaptor;
|
||||
}
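Because the adaptor above is a single merge of (GM10, GN10), CalculateBottomIndex amounts to igm10 = block_id / GN10 and ign10 = block_id % GN10; for example, with GM10 = 4 and GN10 = 8 (so CalculateGridSize returns 32 blocks), block 13 works on the (1, 5) tile. The numbers are illustrative only.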
|
||||
|
||||
using AGKGM0GM10GM11GridDesc = decltype(MakeAGKGM0GM10GM11GridDescriptor(AGKGM0GM1GridDesc{}));
|
||||
using BGKGN0GN10GN11GridDesc = decltype(MakeBGKGN0GN10GN11GridDescriptor(BGKGN0GN1GridDesc{}));
|
||||
using CGM10BM0BM1GN10BN0BN1GridDesc =
|
||||
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
|
||||
using CBlockIdToGM10GN10BlockClusterAdaptor =
|
||||
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AGKGM0GM10GM11GridDesc& a_gk_gm0_gm10_gm11_grid_desc,
|
||||
const BGKGN0GN10GN11GridDesc& b_gk_gn0_gn10_gn11_grid_desc,
|
||||
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_a_grid, a_gk_gm0_gm10_gm11_grid_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_b_grid, b_gk_gn0_gn10_gn11_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());
|
||||
|
||||
const auto GK = a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0);
|
||||
|
||||
// divide block work by [GM10, GN10]
|
||||
const auto c_gm10_gn10_block_cluster_idx =
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this forces the index data into SGPRs
|
||||
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
|
||||
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
|
||||
|
||||
// lds max alignment
|
||||
// part of this should be moved into the blockwise GEMM
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
|
||||
Number<BBlockTransferDstScalarPerVector_GN11>{},
|
||||
Number<M1PerThreadM111>{},
|
||||
Number<N1PerThreadN111>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gk_gm0_gm10_gm11_block_desc =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gk_gn0_gn10_gn11_block_desc =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);
|
||||
|
||||
// A matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gk_bm_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gk_bn_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
|
||||
BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11>,
|
||||
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_gk_gm0_gm10_gm11_grid_desc),
|
||||
decltype(a_gk_gm0_gm10_gm11_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
3,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_GM11,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_gk_gm0_gm10_gm11_grid_desc,
|
||||
make_multi_index(0, 0, igm10, 0),
|
||||
a_gk_gm0_gm10_gm11_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
|
||||
BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11>,
|
||||
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_gk_gn0_gn10_gn11_grid_desc),
|
||||
decltype(b_gk_gn0_gn10_gn11_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_GN11,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_gk_gn0_gn10_gn11_grid_desc,
|
||||
make_multi_index(0, 0, ign10, 0),
|
||||
b_gk_gn0_gn10_gn11_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
|
||||
// b_mtx[KPerBlock, GN1PerBlockGN11] is in LDS
|
||||
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_gk_bm_block_desc),
|
||||
decltype(b_gk_bn_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM1100,
|
||||
M11N11ThreadClusterN1100,
|
||||
M11N11ThreadClusterM1101,
|
||||
M11N11ThreadClusterN1101,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
|
||||
|
||||
constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
|
||||
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
|
||||
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over the A and B matrices for threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
|
||||
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
|
||||
|
||||
// hack to control index calculation when moving the slice window over the A and B
// matrices for threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
|
||||
AGridMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
|
||||
BGridMoveSliceWindowIteratorHacks{};
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block_double, a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block_double, b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < GK - 2 * KPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_gk_gm0_gm10_gm11_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_gk_gn0_gn10_gn11_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if 1 iteration is left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr index_t M11 =
|
||||
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
|
||||
constexpr index_t N11 =
|
||||
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
|
||||
|
||||
constexpr index_t M10 = GM1PerBlockGM11 / M11;
|
||||
constexpr index_t N10 = GN1PerBlockGN11 / N11;
|
||||
|
||||
constexpr index_t M111 = M1PerThreadM111;
|
||||
constexpr index_t N111 = N1PerThreadN111;
|
||||
|
||||
constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(I1,
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
|
||||
I1,
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));
|
||||
|
||||
const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
|
||||
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
|
||||
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
|
||||
Sequence<1,
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
|
||||
1,
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
make_multi_index(igm10,
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
|
||||
ign10,
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
|
||||
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridIteratorHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -5,8 +5,8 @@
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_v2r2.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "blockwise_gemm_v2r3.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
@@ -15,10 +15,10 @@ namespace ck {
|
||||
template <typename GridwiseContraction,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AGK0GM0GM10GM11GK1GridDesc,
|
||||
typename BGK0GN0GN10GN11GK1GridDesc,
|
||||
typename CGM10BM0BM1GN10BN0BN1GridDesc,
|
||||
typename CBlockIdToGM10GN10BlockClusterAdaptor,
|
||||
typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
|
||||
typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
|
||||
typename CGridBlockCluster_BlockId_To_GM10_GN10,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
@@ -29,11 +29,10 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGK0GM0GM10GM11GK1GridDesc a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
const BGK0GN0GN10GN11GK1GridDesc b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
const CBlockIdToGM10GN10BlockClusterAdaptor
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor)
|
||||
const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
@@ -44,10 +43,10 @@ __global__ void
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
@@ -57,19 +56,19 @@ template <index_t BlockSize,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGK0GM0GM1GK1GridDesc,
|
||||
typename BGK0GN0GN1GK1GridDesc,
|
||||
typename CGM0GM1GN0GN1GridDesc,
|
||||
typename AGridDesc_GK0_GM0_GM1_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN1_GK1,
|
||||
typename CGridDesc_GM0_GM1_GN0_GN1,
|
||||
index_t GM1PerBlockGM11,
|
||||
index_t GN1PerBlockGN11,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThreadM111,
|
||||
index_t N1PerThreadN111,
|
||||
index_t KPerThread,
|
||||
index_t M11N11ThreadClusterM1100,
|
||||
index_t M11N11ThreadClusterN1100,
|
||||
index_t M11N11ThreadClusterM1101,
|
||||
index_t M11N11ThreadClusterN1101,
|
||||
index_t GK0PerBlock,
|
||||
index_t BM1PerThreadBM11,
|
||||
index_t BN1PerThreadBN11,
|
||||
index_t BK0PerThread,
|
||||
index_t BM10BN10ThreadClusterBM100,
|
||||
index_t BM10BN10ThreadClusterBN100,
|
||||
index_t BM10BN10ThreadClusterBM101,
|
||||
index_t BM10BN10ThreadClusterBN101,
|
||||
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
@@ -92,7 +91,7 @@ template <index_t BlockSize,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -100,9 +99,9 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// GM0 and GN0 need to be known at compile-time
static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
|
||||
static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
|
||||
static constexpr auto GK1 = AGK0GM0GM1GK1GridDesc{}.GetLength(I3);
|
||||
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
|
||||
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
|
||||
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
@@ -113,61 +112,62 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
|
||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
|
||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
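GetSharedMemoryNumberOfByte above rounds each LDS tile up to the alignment and doubles the sum for double buffering. A minimal host-side sketch of the same formula, assuming hypothetical tile sizes and using a 2-byte stand-in for the fp16 element type, is:

#include <cstdio>

// round x up to a multiple of align (same role as math::integer_least_multiple)
constexpr long integer_least_multiple(long x, long align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    constexpr long a_tile_elems   = 8L * 2 * 1 * 128 * 4; // GK0PerBlock * GM0 * 1 * GM1PerBlockGM11 * GK1 (hypothetical)
    constexpr long b_tile_elems   = 8L * 2 * 1 * 128 * 4; // same shape assumed for B
    constexpr long max_lds_align  = 8;                    // hypothetical alignment, in elements

    const long a_space = integer_least_multiple(a_tile_elems, max_lds_align);
    const long b_space = integer_least_multiple(b_tile_elems, max_lds_align);

    // factor 2 = double buffering (even/odd LDS buffers)
    std::printf("LDS bytes = %ld\n", 2 * (a_space + b_space) * (long)sizeof(short));
    return 0;
}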
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
|
||||
const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
|
||||
"wrong! GM0 and GN0 need to be known at compile-time");
|
||||
|
||||
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
|
||||
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
|
||||
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
|
||||
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
|
||||
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
|
||||
GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
|
||||
GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
|
||||
GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
|
||||
GM0 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I1) &&
|
||||
GM1 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2) &&
|
||||
GN0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I1) &&
|
||||
GN1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2) &&
|
||||
GK0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0) &&
|
||||
GK1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I3)) &&
|
||||
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % KPerBlock == 0));
|
||||
return (
|
||||
(GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
|
||||
GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
|
||||
GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
|
||||
GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
|
||||
GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
|
||||
GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
|
||||
GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
|
||||
GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
|
||||
GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
|
||||
GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
|
||||
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
|
||||
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr index_t GM11 = GM1PerBlockGM11;
|
||||
constexpr index_t GN11 = GN1PerBlockGN11;
|
||||
@@ -182,29 +182,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
|
||||
{
|
||||
const bool has_main_k_block_loop = (GK0 + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (GK0 / KPerBlock) % 2 == 0;
|
||||
const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
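The two predicates above decide whether the unrolled-by-two main loop runs at all and whether the tail consumes one or two K tiles. A small worked example, assuming a hypothetical GK0PerBlock of 8:

#include <cassert>

int main()
{
    static constexpr int GK0PerBlock = 8; // hypothetical tile size
    auto has_main  = [](int gk0) { return (gk0 + GK0PerBlock) / (2 * GK0PerBlock) > 1; };
    auto has_2tail = [](int gk0) { return (gk0 / GK0PerBlock) % 2 == 0; };

    assert(has_main(32) && has_2tail(32));   // main loop + two-tile tail
    assert(!has_main(16) && has_2tail(16));  // only the two-tile tail
    assert(!has_main(8) && !has_2tail(8));   // single-tile tail
    return 0;
}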
__host__ __device__ static constexpr auto
|
||||
MakeAGK0GM0GM10GM11GK1GridDescriptor(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc)
|
||||
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
|
||||
{
|
||||
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
|
||||
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
|
||||
|
||||
const auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
const auto GM10 = GM1 / GM11;
|
||||
|
||||
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
a_gk0_gm0_gm1_gk1_grid_desc,
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
make_tuple(make_pass_through_transform(GK0),
|
||||
make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
@@ -212,20 +212,20 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return a_gk0_gm0_gm10_gm11_gk1_grid_desc;
|
||||
return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
|
||||
}
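The unmerge transform above splits GM1 into (GM10, GM11) with GM11 = GM1PerBlockGM11, so a flat gm1 coordinate maps to (gm1 / GM11, gm1 % GM11). A tiny sketch of that index arithmetic, with hypothetical values:

#include <cstdio>

int main()
{
    constexpr int GM1PerBlockGM11 = 128; // hypothetical block tile along GM1
    const int gm1  = 300;                // some flat GM1 coordinate
    const int gm10 = gm1 / GM1PerBlockGM11;
    const int gm11 = gm1 % GM1PerBlockGM11;
    std::printf("gm1=%d -> (gm10=%d, gm11=%d)\n", gm1, gm10, gm11);
    return 0;
}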
__host__ __device__ static constexpr auto
|
||||
MakeBGK0GN0GN10GN11GK1GridDescriptor(const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc)
|
||||
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
|
||||
{
|
||||
const auto GK0 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0);
|
||||
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
|
||||
const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
|
||||
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
|
||||
|
||||
const auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
b_gk0_gn0_gn1_gk1_grid_desc,
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
make_tuple(make_pass_through_transform(GK0),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11)),
|
||||
@@ -233,14 +233,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return b_gk0_gn0_gn10_gn11_gk1_grid_desc;
|
||||
return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
|
||||
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
@@ -252,15 +252,15 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
constexpr auto BN = GN0 * GN11;
|
||||
|
||||
constexpr auto BM1 =
|
||||
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
|
||||
Number<BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11>{};
|
||||
constexpr auto BN1 =
|
||||
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
|
||||
Number<BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11>{};
|
||||
|
||||
constexpr auto BM0 = BM / BM1;
|
||||
constexpr auto BN0 = BN / BN1;
|
||||
|
||||
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
c_gm0_gm1_gn0_gn1_grid_desc,
|
||||
c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
make_tuple(make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
make_pass_through_transform(GN0),
|
||||
@@ -277,7 +277,7 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor(
|
||||
c_gm10_bm_gn10_bn_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_unmerge_transform(make_tuple(BM0, BM1)),
|
||||
@@ -286,14 +286,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
|
||||
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
|
||||
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
|
||||
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
|
||||
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
@@ -301,22 +301,22 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return c_blockid_to_gm10_gn10_block_cluster_adaptor;
|
||||
return c_grid_block_cluster_blockid_to_gm10_gn10;
|
||||
}
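The single-stage adaptor above merges (GM10, GN10) into one flat block id, so CalculateBottomIndex recovers the pair by division and modulo. A sketch of that inverse mapping, with hypothetical grid extents:

#include <cstdio>

int main()
{
    const int GM10 = 4, GN10 = 6; // hypothetical grid of block tiles
    for(int block_id = 0; block_id < GM10 * GN10; ++block_id)
    {
        const int gm10 = block_id / GN10;
        const int gn10 = block_id % GN10;
        std::printf("block %2d -> (gm10=%d, gn10=%d)\n", block_id, gm10, gn10);
    }
    return 0;
}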
using AGK0GM0GM10GM11GK1GridDesc =
|
||||
decltype(MakeAGK0GM0GM10GM11GK1GridDescriptor(AGK0GM0GM1GK1GridDesc{}));
|
||||
using BGK0GN0GN10GN11GK1GridDesc =
|
||||
decltype(MakeBGK0GN0GN10GN11GK1GridDescriptor(BGK0GN0GN1GK1GridDesc{}));
|
||||
using CGM10BM0BM1GN10BN0BN1GridDesc =
|
||||
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
|
||||
using CBlockIdToGM10GN10BlockClusterAdaptor =
|
||||
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
|
||||
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
|
||||
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
@@ -324,25 +324,25 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AGK0GM0GM10GM11GK1GridDesc& a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
const BGK0GN0GN10GN11GK1GridDesc& b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
|
||||
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_a_grid, a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_b_grid, b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());
|
||||
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
|
||||
|
||||
const auto GK0 = a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0);
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
|
||||
|
||||
// divide block work by [GM10, GN10]
|
||||
const auto c_gm10_gn10_block_cluster_idx =
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this forces index data into SGPR
@@ -356,46 +356,46 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
|
||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
|
||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// A matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_gk0_bm_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
|
||||
constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_gk0_bn_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
|
||||
constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
|
||||
|
||||
static_assert(a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize() ==
|
||||
a_gk0_bm_gk1_block_desc.GetElementSpaceSize() &&
|
||||
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize() ==
|
||||
b_gk0_bn_gk1_block_desc.GetElementSpaceSize(),
|
||||
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
|
||||
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
|
||||
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
|
||||
"wrong!");
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
|
||||
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc),
|
||||
decltype(a_gk0_gm0_gm10_gm11_gk1_block_desc),
|
||||
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
|
||||
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
|
||||
@@ -403,23 +403,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
make_multi_index(0, 0, igm10, 0, 0),
|
||||
a_gk0_gm0_gm10_gm11_gk1_block_desc,
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
make_multi_index(0, 0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
|
||||
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc),
|
||||
decltype(b_gk0_gn0_gn10_gn11_gk1_block_desc),
|
||||
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
|
||||
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
|
||||
@@ -427,102 +427,103 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
make_multi_index(0, 0, ign10, 0, 0),
|
||||
b_gk0_gn0_gn10_gn11_gk1_block_desc,
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
make_multi_index(0, 0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
|
||||
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
|
||||
// b_mtx[KPerBlock, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
|
||||
// register
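For reference, the contraction performed by the blockwise GEMM below is C += transpose(A) * B with both operands stored K-major in LDS. A plain scalar sketch of those semantics; names and layouts here are illustrative, not the kernel's actual data structures:

// C[m][n] += sum_k A[k][m] * B[k][n], with A of shape KxM and B of shape KxN
void reference_gemm_kxm_kxn(const float* a, const float* b, float* c, int K, int M, int N)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[k * M + m] * b[k * N + n]; // A[k][m] * B[k][n]
            c[m * N + n] += acc;
        }
}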
const auto blockwise_gemm =
|
||||
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_gk0_bm_gk1_block_desc),
|
||||
decltype(b_gk0_bn_gk1_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM1100,
|
||||
M11N11ThreadClusterN1100,
|
||||
M11N11ThreadClusterM1101,
|
||||
M11N11ThreadClusterN1101,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_gk0_bm_gk1),
|
||||
decltype(b_block_desc_gk0_bn_gk1),
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM100,
|
||||
BM10BN10ThreadClusterBN100,
|
||||
BM10BN10ThreadClusterBM101,
|
||||
BM10BN10ThreadClusterBN101,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11>{};
|
||||
|
||||
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
|
||||
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
|
||||
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));
|
||||
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());
|
||||
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
|
||||
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
|
||||
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
|
||||
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
|
||||
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
|
||||
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block_double, a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
|
||||
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block_double, b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
|
||||
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t k_block_data_begin = 0;
|
||||
index_t gk0_block_on_grid = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use a do-while loop instead of a for loop to simplify control flow
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
|
||||
@@ -530,25 +531,25 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
|
||||
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
|
||||
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
|
||||
@@ -556,29 +557,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
|
||||
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < GK0 - 2 * KPerBlock);
|
||||
gk0_block_on_grid += 2 * GK0PerBlock;
|
||||
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
|
||||
}
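The do-while body above is the usual LDS double-buffering pipeline, unrolled by two so the even and odd buffers swap roles each half-iteration, while the tail below finishes the last one or two K tiles. A schematic host-side sketch of the single-step version of the same idea; the tile type and the load/compute helpers are stand-ins, not the kernel's API:

#include <vector>

struct Tile { std::vector<float> data; };

void pipeline(int num_k_tiles)
{
    Tile buf[2];
    auto load    = [](Tile&, int /*k_tile*/) { /* global -> LDS copy */ };
    auto compute = [](const Tile&) { /* blockwise GEMM */ };

    load(buf[0], 0);                       // preload the first tile
    for(int k = 0; k + 1 < num_k_tiles; ++k)
    {
        load(buf[(k + 1) % 2], k + 1);     // prefetch the next tile into the other buffer
        compute(buf[k % 2]);               // GEMM on the tile loaded previously
    }
    compute(buf[(num_k_tiles - 1) % 2]);   // tail: GEMM on the last tile
}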
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowIteratorHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowIteratorHacks{});
|
||||
|
||||
@@ -586,23 +587,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
@@ -610,61 +611,51 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr index_t M11 =
|
||||
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
|
||||
constexpr index_t N11 =
|
||||
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
|
||||
|
||||
constexpr index_t M10 = GM1PerBlockGM11 / M11;
|
||||
constexpr index_t N10 = GN1PerBlockGN11 / N11;
|
||||
|
||||
constexpr index_t M111 = M1PerThreadM111;
|
||||
constexpr index_t N111 = N1PerThreadN111;
|
||||
|
||||
constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
|
||||
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(I1,
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
|
||||
I1,
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
|
||||
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
|
||||
|
||||
const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
|
||||
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
|
||||
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
|
||||
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
|
||||
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
|
||||
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
||||
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
||||
Sequence<1,
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
|
||||
1,
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
|
||||
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
make_multi_index(igm10,
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
|
||||
ign10,
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
|
||||
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
|
||||
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
|
||||
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
make_multi_index(igm10,
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
|
||||
ign10,
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
|
||||
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_buf,
|
||||
CGridIteratorHacks{});
|
||||
}
|
||||
|
||||
@@ -1,552 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_HPP
|
||||
#define CK_GRIDWISE_DYNAMIC_GEMM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_v2.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
|
||||
const FloatB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc a_k_m_global_desc,
|
||||
const BGlobalDesc b_k_n_global_desc,
|
||||
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
|
||||
const CBlockClusterDesc c_block_cluster_desc)
|
||||
{
|
||||
GridwiseGemm::Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
// point to the non-modifiable parameter address space, so the compiler can enable the
// corresponding optimization
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
|
||||
const FloatB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const void __CONSTANT__* p_a_k_m_global_desc,
|
||||
const void __CONSTANT__* p_b_k_n_global_desc,
|
||||
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
|
||||
const void __CONSTANT__* p_c_block_cluster_desc)
|
||||
{
|
||||
// first cast the __CONSTANT__ void pointer to a generic void pointer
// then cast the void pointer to a pointer to the concrete descriptor type
// (the copy constructor of the tensor descriptor does not accept an address_space(4) pointer)
const auto a_k_m_global_desc =
|
||||
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
|
||||
const auto b_k_n_global_desc =
|
||||
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
|
||||
const auto c_m0_m1_n0_n1_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
|
||||
|
||||
const auto c_block_cluster_desc =
|
||||
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
|
||||
|
||||
GridwiseGemm::Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#endif
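The removed kernel above illustrates the pattern of shipping a tensor descriptor through an opaque pointer and casting it back to its concrete type inside the kernel. A host-only analogue of that pattern, with the AMDGPU address-space qualifier omitted and purely illustrative names:

#include <cstdio>

struct DescExample { int lengths[2]; int strides[2]; };

void consume_desc(const void* p_opaque)
{
    // cast the opaque pointer back to the concrete descriptor type before use
    const auto& desc = *static_cast<const DescExample*>(p_opaque);
    std::printf("lengths = [%d, %d]\n", desc.lengths[0], desc.lengths[1]);
}

int main()
{
    const DescExample desc{{64, 128}, {128, 1}};
    consume_desc(&desc); // in the real kernel this crosses the host/device boundary
    return 0;
}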
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM10,
|
||||
index_t M1N1ThreadClusterN10,
|
||||
index_t M1N1ThreadClusterM11,
|
||||
index_t M1N1ThreadClusterN11,
|
||||
typename ABlockTransferThreadSliceLengths_K_M,
|
||||
typename ABlockTransferThreadClusterLengths_K_M,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_M,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N,
|
||||
typename BBlockTransferThreadClusterLengths_K_N,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_N,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalIteratorHacks,
|
||||
typename BGlobalIteratorHacks,
|
||||
typename CGlobalIteratorHacks,
|
||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
||||
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
|
||||
Number<BBlockTransferDstScalarPerVector_N>{},
|
||||
Number<M1PerThread>{},
|
||||
Number<N1PerThread>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc& a_k_m_global_desc,
|
||||
const BGlobalDesc& b_k_n_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_a_global, a_k_m_global_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_b_global, b_k_n_global_desc.GetElementSpaceSize());
|
||||
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K = a_k_m_global_desc.GetLength(I0);
|
||||
const auto M = a_k_m_global_desc.GetLength(I1);
|
||||
const auto N = b_k_n_global_desc.GetLength(I1);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this forces m/n_block_data_idx_on_global into SGPR
const index_t m_block_data_idx_on_global =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_global =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
|
||||
Number<BBlockTransferDstScalarPerVector_N>{},
|
||||
Number<M1PerThread>{},
|
||||
Number<N1PerThread>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, MPerBlock>,
|
||||
ABlockTransferThreadSliceLengths_K_M,
|
||||
ABlockTransferThreadClusterLengths_K_M,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k_m_global_desc),
|
||||
decltype(a_k_m_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
a_k_m_global_desc,
|
||||
make_multi_index(0, m_block_data_idx_on_global),
|
||||
a_k_m_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, NPerBlock>,
|
||||
BBlockTransferThreadSliceLengths_K_N,
|
||||
BBlockTransferThreadClusterLengths_K_N,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k_n_global_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_k_n_global_desc,
|
||||
make_multi_index(0, n_block_data_idx_on_global),
|
||||
b_k_n_block_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
static_assert(
|
||||
MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
|
||||
NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t M0PerThread =
|
||||
MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
|
||||
constexpr index_t N0PerThread =
|
||||
NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);
|
||||
|
||||
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
a_k_m_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<M0PerThread>{},
|
||||
Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
b_k_n_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<N0PerThread>{},
|
||||
Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
constexpr auto c_m0_m1_n0_n1_thread_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<M0PerThread>{},
|
||||
Number<M1PerThread>{},
|
||||
Number<N0PerThread>{},
|
||||
Number<N1PerThread>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k_m0_m1_block_desc),
|
||||
decltype(b_k_n0_n1_block_desc),
|
||||
decltype(c_m0_m1_n0_n1_thread_desc),
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
M1PerThread,
|
||||
N1PerThread>{};
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
|
||||
c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseDynamicTensorSliceSet_v1<
|
||||
FloatAcc,
|
||||
decltype(c_m0_m1_n0_n1_thread_desc),
|
||||
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
|
||||
.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise copy
        constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};

        // hack to control index calculation when moving the slice window for A and B matrix
        // for threadwise copy
        constexpr auto a_k_m_global_move_slice_window_iterator_hack =
            AGlobalMoveSliceWindowIteratorHacks{};
        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
            BGlobalMoveSliceWindowIteratorHacks{};

        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
        }
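
        // Main loop note: each do-while pass below consumes two K-slices. The global reads
        // for the next slice are issued before the GEMM on the current slice, and the
        // even/odd LDS buffers alternate so RunWrite never touches the tile the in-flight
        // GEMM is still reading.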
        if constexpr(HasMainKBlockLoop)
        {
            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);
                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
                                                    b_block_slice_copy_step,
                                                    b_k_n_global_move_slice_window_iterator_hack);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);
                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
                                                    b_block_slice_copy_step,
                                                    b_k_n_global_move_slice_window_iterator_hack);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);

                k_block_data_begin += 2 * KPerBlock;
            } while(k_block_data_begin < K - 2 * KPerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                a_block_slice_copy_step,
                                                a_k_m_global_move_slice_window_iterator_hack);
            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
                                                b_block_slice_copy_step,
                                                b_k_n_global_move_slice_window_iterator_hack);

            __syncthreads();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if 1 iteration is left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
        }
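
        // Write-out note: CalculateCThreadOriginDataIndex maps the flat thread id to this
        // thread's (M0, M1, N0, N1) origin inside the block tile; the block offsets are folded
        // in as m_block_data_idx_on_global / M1 and n_block_data_idx_on_global / N1 before the
        // threadwise transfer copies c_thread_buf out to global memory.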
        // output: register to global memory
        {
            constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
            constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};

            // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};

            const auto c_thread_data_idx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());

            ThreadwiseDynamicTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_m0_m1_n0_n1_thread_desc),
                decltype(c_m0_m1_n0_n1_global_desc),
                Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{
                c_m0_m1_n0_n1_global_desc,
                make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
                                 c_thread_data_idx_on_block[I1],
                                 n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
                                 c_thread_data_idx_on_block[I3])}
                .Run(c_m0_m1_n0_n1_thread_desc,
                     make_tuple(I0, I0, I0, I0),
                     c_thread_buf,
                     c_m0_m1_n0_n1_global_desc,
                     c_global_buf,
                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
        }
    }

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                               const FloatAB* __restrict__ p_b_global,
                               FloatC* __restrict__ p_c_global,
                               const AGlobalDesc& a_k_m_global_desc,
                               const BGlobalDesc& b_k_n_global_desc,
                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                               const CBlockClusterDesc& c_block_cluster_desc,
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

        __shared__ FloatAB p_shared_block[shared_block_size];

        Run(p_a_global,
            p_b_global,
            p_c_global,
            a_k_m_global_desc,
            b_k_n_global_desc,
            c_m0_m1_n0_n1_global_desc,
            c_block_cluster_desc,
            p_shared_block,
            integral_constant<bool, HasMainKBlockLoop>{},
            integral_constant<bool, HasDoubleTailKBlockLoop>{});
    }
};

} // namespace ck
#endif
@@ -435,21 +435,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
|
||||
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM1100,
|
||||
M11N11ThreadClusterN1100,
|
||||
M11N11ThreadClusterM1101,
|
||||
M11N11ThreadClusterN1101,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM1100,
|
||||
M11N11ThreadClusterN1100,
|
||||
M11N11ThreadClusterM1101,
|
||||
M11N11ThreadClusterN1101,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
|
||||
|
||||
@@ -1,40 +1,44 @@
|
||||
#ifndef CK_THREADWISE_GEMM_V2_HPP
|
||||
#define CK_THREADWISE_GEMM_V2_HPP
|
||||
#ifndef CK_THREADWISE_CONTRACTION_HPP
|
||||
#define CK_THREADWISE_CONTRACTION_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "math.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
|
||||
// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
|
||||
// Tensor element can be vectorized data
|
||||
// Assume:
|
||||
// 1. ADesc, BDesc, CDesc are known at compile-time
|
||||
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
|
||||
// known at compile-time
|
||||
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
|
||||
template <typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc,
|
||||
typename KLengths,
|
||||
typename MLengths,
|
||||
typename NLengths,
|
||||
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
typename AThreadDesc_TK0_TM0_TM1_TK1,
|
||||
typename BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
typename CThreadDesc_TM0_TM1_TN0_TN1,
|
||||
typename TKLengths,
|
||||
typename TMLengths,
|
||||
typename TNLengths,
|
||||
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
|
||||
{
|
||||
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
|
||||
{
|
||||
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
|
||||
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
        // CThreadDesc_TM0_TM1_TN0_TN1 Size with TKLengths, TMLengths and TNLengths
|
||||
|
||||
// TODO remove this restriction
|
||||
static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
|
||||
static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
@@ -70,28 +74,31 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto K = KLengths{}[I0];
|
||||
constexpr auto M0 = MLengths{}[I0];
|
||||
constexpr auto M1 = MLengths{}[I1];
|
||||
constexpr auto N0 = NLengths{}[I0];
|
||||
constexpr auto N1 = NLengths{}[I1];
|
||||
constexpr auto TK = TKLengths{}[I0];
|
||||
constexpr auto TM0 = TMLengths{}[I0];
|
||||
constexpr auto TM1 = TMLengths{}[I1];
|
||||
constexpr auto TN0 = TNLengths{}[I0];
|
||||
constexpr auto TN1 = TNLengths{}[I1];
|
||||
|
||||
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, K, 1>{}([&](auto k) {
|
||||
static_for<0, M0, 1>{}([&](auto m0) {
|
||||
static_for<0, M1, 1>{}([&](auto m1) {
|
||||
static_for<0, N0, 1>{}([&](auto n0) {
|
||||
static_for<0, N1, 1>{}([&](auto n1) {
|
||||
static_for<0, TK, 1>{}([&](auto tk) {
|
||||
static_for<0, TM0, 1>{}([&](auto tm0) {
|
||||
static_for<0, TM1, 1>{}([&](auto tm1) {
|
||||
static_for<0, TN0, 1>{}([&](auto tn0) {
|
||||
static_for<0, TN1, 1>{}([&](auto tn1) {
|
||||
|
||||
constexpr index_t a_offset =
|
||||
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
|
||||
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk, tm0, tm1));
|
||||
constexpr index_t b_offset =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
|
||||
constexpr index_t c_offset = CDesc{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(m0, m1, n0, n1));
|
||||
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk, tn0, tn1));
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
|
||||
amd_inner_product_dlop<FloatA, FloatB, FloatC>(
|
||||
a_buf[Number<a_offset>{}],
|
||||
@@ -105,35 +112,39 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
|
||||
}
|
||||
};
|
||||
|
||||
// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1]
|
||||
// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
|
||||
// Tensor element can be vectorized data
|
||||
// Assume:
|
||||
// 1. ADesc, BDesc, CDesc are known at compile-time
|
||||
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
|
||||
// known at compile-time
|
||||
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
|
||||
template <typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc,
|
||||
typename KLengths,
|
||||
typename MLengths,
|
||||
typename NLengths,
|
||||
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
typename AThreadDesc_TK0_TM0_TM1_TK1,
|
||||
typename BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
typename CThreadDesc_TM0_TM1_TN0_TN1,
|
||||
typename TKLengths,
|
||||
typename TMLengths,
|
||||
typename TNLengths,
|
||||
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
|
||||
struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
|
||||
{
|
||||
__device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1()
|
||||
__device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
|
||||
{
|
||||
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
|
||||
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
        // CThreadDesc_TM0_TM1_TN0_TN1 Size with TKLengths, TMLengths and TNLengths
|
||||
|
||||
// TODO remove this restriction
|
||||
static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2,
|
||||
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
@@ -169,43 +180,45 @@ struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t K0 = KLengths{}[I0];
|
||||
constexpr index_t K1 = KLengths{}[I1];
|
||||
constexpr index_t M0 = MLengths{}[I0];
|
||||
constexpr index_t M1 = MLengths{}[I1];
|
||||
constexpr index_t N0 = NLengths{}[I0];
|
||||
constexpr index_t N1 = NLengths{}[I1];
|
||||
constexpr index_t TK0 = TKLengths{}[I0];
|
||||
constexpr index_t TK1 = TKLengths{}[I1];
|
||||
constexpr index_t TM0 = TMLengths{}[I0];
|
||||
constexpr index_t TM1 = TMLengths{}[I1];
|
||||
constexpr index_t TN0 = TNLengths{}[I0];
|
||||
constexpr index_t TN1 = TNLengths{}[I1];
|
||||
|
||||
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, K0, 1>{}([&](auto k0) {
|
||||
static_for<0, M0, 1>{}([&](auto m0) {
|
||||
static_for<0, M1, 1>{}([&](auto m1) {
|
||||
static_for<0, N0, 1>{}([&](auto n0) {
|
||||
static_for<0, N1, 1>{}([&](auto n1) {
|
||||
static_for<0, TK0, 1>{}([&](auto tk0) {
|
||||
static_for<0, TM0, 1>{}([&](auto tm0) {
|
||||
static_for<0, TM1, 1>{}([&](auto tm1) {
|
||||
static_for<0, TN0, 1>{}([&](auto tn0) {
|
||||
static_for<0, TN1, 1>{}([&](auto tn1) {
|
||||
|
||||
vector_type<FloatA, K1> a_vec;
|
||||
vector_type<FloatB, K1> b_vec;
|
||||
vector_type<FloatA, TK1> a_vec;
|
||||
vector_type<FloatB, TK1> b_vec;
|
||||
|
||||
static_for<0, K1, 1>{}([&](auto k1) {
|
||||
constexpr index_t a_offset = ADesc{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(k0, m0, m1, k1));
|
||||
static_for<0, TK1, 1>{}([&](auto tk1) {
|
||||
constexpr index_t a_offset =
|
||||
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
|
||||
|
||||
constexpr index_t b_offset = BDesc{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(k0, n0, n1, k1));
|
||||
constexpr index_t b_offset =
|
||||
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
|
||||
|
||||
a_vec.template AsType<FloatA>()(k1) = a_buf[Number<a_offset>{}];
|
||||
|
||||
b_vec.template AsType<FloatB>()(k1) = b_buf[Number<b_offset>{}];
|
||||
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
|
||||
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
|
||||
});
|
||||
|
||||
using a_vector_t = typename vector_type<FloatA, K1>::type;
|
||||
using b_vector_t = typename vector_type<FloatB, K1>::type;
|
||||
using a_vector_t = typename vector_type<FloatA, TK1>::type;
|
||||
using b_vector_t = typename vector_type<FloatB, TK1>::type;
|
||||
|
||||
constexpr index_t c_offset = CDesc{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(m0, m1, n0, n1));
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
|
||||
amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
|
||||
a_vec.template AsType<a_vector_t>()[I0],
|
||||
@@ -1,379 +0,0 @@
|
||||
#include "common_header.hpp"
|
||||
#include "type_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_contraction_v1r1.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
|
||||
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
|
||||
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
constexpr index_t N0 = CK_PARAM_N0;
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
|
||||
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
|
||||
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
|
||||
constexpr index_t KPerThread = CK_PARAM_KPerThread;
|
||||
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
|
||||
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
|
||||
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
|
||||
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11>;
|
||||
using ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_GM11 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_GM11;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11>;
|
||||
using BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_GN11 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_GN11;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
|
||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
|
||||
|
||||
extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_prepare(
|
||||
int n,
|
||||
int c,
|
||||
int hi,
|
||||
int wi,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_gk_gm0_gm10_gm11_grid_desc,
|
||||
void* p_b_gk_gn0_gn10_gn11_grid_desc,
|
||||
void* p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
void* p_c_blockid_to_gm10_gn10_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW));
|
||||
|
||||
const auto a_gk_gm0_gm1_grid_desc = descs[I0];
|
||||
const auto b_gk_gn0_gn1_grid_desc = descs[I1];
|
||||
const auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
|
||||
|
||||
using AGKGM0GM1GridDesc = decltype(a_gk_gm0_gm1_grid_desc);
|
||||
using BGKGN0GN1GridDesc = decltype(b_gk_gn0_gn1_grid_desc);
|
||||
using CGM0GM1GN0GN1GridDesc = decltype(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set, /* ToDo tunable */
|
||||
AGKGM0GM1GridDesc,
|
||||
BGKGN0GN1GridDesc,
|
||||
CGM0GM1GN0GN1GridDesc,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_GM11,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_GN11,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
|
||||
auto a_gk_gm0_gm10_gm11_grid_desc =
|
||||
GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc);
|
||||
auto b_gk_gn0_gn10_gn11_grid_desc =
|
||||
GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc);
|
||||
auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
|
||||
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
|
||||
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<decltype(a_gk_gm0_gm10_gm11_grid_desc)*>(p_a_gk_gm0_gm10_gm11_grid_desc) =
|
||||
a_gk_gm0_gm10_gm11_grid_desc;
|
||||
*static_cast<decltype(b_gk_gn0_gn10_gn11_grid_desc)*>(p_b_gk_gn0_gn10_gn11_grid_desc) =
|
||||
b_gk_gn0_gn10_gn11_grid_desc;
|
||||
*static_cast<decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc)*>(
|
||||
p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc) = c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
|
||||
*static_cast<decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor)*>(
|
||||
p_c_blockid_to_gm10_gn10_block_cluster_adaptor) =
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor;
|
||||
};
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void __CONSTANT__* p_a_gk_gm0_gm10_gm11_grid_desc,
|
||||
const void __CONSTANT__* p_b_gk_gn0_gn10_gn11_grid_desc,
|
||||
const void __CONSTANT__* p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
const void __CONSTANT__* p_c_blockid_to_gm10_gn10_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1));
|
||||
|
||||
constexpr auto a_gk_gm0_gm1_grid_desc = descs[I0];
|
||||
constexpr auto b_gk_gn0_gn1_grid_desc = descs[I1];
|
||||
constexpr auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
|
||||
|
||||
using AGKGM0GM1GridDesc = decltype(a_gk_gm0_gm1_grid_desc);
|
||||
using BGKGN0GN1GridDesc = decltype(b_gk_gn0_gn1_grid_desc);
|
||||
using CGM0GM1GN0GN1GridDesc = decltype(c_gm0_gm1_gn0_gn1_grid_desc);
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set, /* ToDo tunable */
|
||||
AGKGM0GM1GridDesc,
|
||||
BGKGN0GN1GridDesc,
|
||||
CGM0GM1GN0GN1GridDesc,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_GM11,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_GN11,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
|
||||
using AGKGM0GM10GM11GridDesc =
|
||||
decltype(GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc));
|
||||
using BGKGN0GN10GN11GridDesc =
|
||||
decltype(GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc));
|
||||
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(
|
||||
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc));
|
||||
using CBlockIdToGM10GN10BlockClusterAdaptor =
|
||||
decltype(GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(
|
||||
c_gm0_gm1_gn0_gn1_grid_desc));
|
||||
|
||||
const auto a_gk_gm0_gm10_gm11_grid_desc = *reinterpret_cast<const AGKGM0GM10GM11GridDesc*>(
|
||||
(const void*)p_a_gk_gm0_gm10_gm11_grid_desc);
|
||||
const auto b_gk_gn0_gn10_gn11_grid_desc = *reinterpret_cast<const BGKGN0GN10GN11GridDesc*>(
|
||||
(const void*)p_b_gk_gn0_gn10_gn11_grid_desc);
|
||||
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
|
||||
*reinterpret_cast<const CGM10BM0BM1GN10BN0BN1GridDesc*>(
|
||||
(const void*)p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
|
||||
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToGM10GN10BlockClusterAdaptor*>(
|
||||
(const void*)p_c_blockid_to_gm10_gn10_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_gk_gm0_gm10_gm11_grid_desc,
|
||||
b_gk_gn0_gn10_gn11_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
};
|
||||
@@ -0,0 +1,402 @@
|
||||
#include "common_header.hpp"
|
||||
#include "type_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_contraction_v1r2.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
|
||||
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_ACC_DATATYPE)>::type;
|
||||
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr auto GN0 = Number<CK_PARAM_GN0>{};
|
||||
constexpr auto GK1 = Number<CK_PARAM_GK1>{};
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
|
||||
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
|
||||
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
|
||||
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
|
||||
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
|
||||
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
|
||||
constexpr index_t BM10BN10ThreadClusterBM100 = CK_PARAM_BM10BN10ThreadClusterBM100;
|
||||
constexpr index_t BM10BN10ThreadClusterBN100 = CK_PARAM_BM10BN10ThreadClusterBN100;
|
||||
constexpr index_t BM10BN10ThreadClusterBM101 = CK_PARAM_BM10BN10ThreadClusterBM101;
|
||||
constexpr index_t BM10BN10ThreadClusterBN101 = CK_PARAM_BM10BN10ThreadClusterBN101;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>;
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>;
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = 5;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
|
||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
|
||||
|
||||
extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare(
|
||||
index_t N,
|
||||
index_t C,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t K,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t ConvStrideH,
|
||||
index_t ConvStrideW,
|
||||
index_t ConvDilationH,
|
||||
index_t ConvDilationW,
|
||||
index_t InLeftPadH,
|
||||
index_t InLeftPadW,
|
||||
index_t InRightPadH,
|
||||
index_t InRightPadW,
|
||||
void* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
void* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
void* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
void* p_c_grid_block_cluster_blockid_to_gm10_gn10)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t Ho =
|
||||
(Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
|
||||
const index_t Wo =
|
||||
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi));
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C, Y, X));
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(ConvStrideH, ConvStrideW),
|
||||
make_tuple(ConvDilationH, ConvDilationW),
|
||||
make_tuple(InLeftPadH, InLeftPadW),
|
||||
make_tuple(InRightPadH, InRightPadW),
|
||||
GN0,
|
||||
GK1);
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: BN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: BN1
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowIteratorHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM100,
|
||||
BM10BN10ThreadClusterBN100,
|
||||
BM10BN10ThreadClusterBM101,
|
||||
BM10BN10ThreadClusterBN101,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
|
||||
auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
auto c_grid_block_cluster_blockid_to_gm10_gn10 =
|
||||
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1)*>(
|
||||
p_a_grid_desc_gk0_gm0_gm10_gm11_gk1) = a_grid_desc_gk0_gm0_gm10_gm11_gk1;
|
||||
*static_cast<decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1)*>(
|
||||
p_b_grid_desc_gk0_gn0_gn10_gn11_gk1) = b_grid_desc_gk0_gn0_gn10_gn11_gk1;
|
||||
*static_cast<decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1)*>(
|
||||
p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1) = c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
|
||||
*static_cast<decltype(c_grid_block_cluster_blockid_to_gm10_gn10)*>(
|
||||
p_c_grid_block_cluster_blockid_to_gm10_gn10) =
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10;
|
||||
};
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void __CONSTANT__* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const void __CONSTANT__* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const void __CONSTANT__* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const void __CONSTANT__* p_c_grid_block_cluster_blockid_to_gm10_gn10)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
GN0,
|
||||
GK1);
|
||||
|
||||
constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: BN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: BN1
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowIteratorHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM100,
|
||||
BM10BN10ThreadClusterBN100,
|
||||
BM10BN10ThreadClusterBM101,
|
||||
BM10BN10ThreadClusterBN101,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1));
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
|
||||
decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1));
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
|
||||
decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
*reinterpret_cast<const AGridDesc_GK0_GM0_GM10_GM11_GK1*>(
|
||||
(const void*)p_a_grid_desc_gk0_gm0_gm10_gm11_gk1);
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
*reinterpret_cast<const BGridDesc_GK0_GN0_GN10_GN11_GK1*>(
|
||||
(const void*)p_b_grid_desc_gk0_gn0_gn10_gn11_gk1);
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
*reinterpret_cast<const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1*>(
|
||||
(const void*)p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
|
||||
*reinterpret_cast<const CGridBlockCluster_BlockId_To_GM10_GN10*>(
|
||||
(const void*)p_c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
};
|
||||
@@ -1,271 +0,0 @@
#ifndef CONV_TUNABLES_HPP
#define CONV_TUNABLES_HPP

#include "config.hpp"

struct tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw
{
    ck::index_t BlockSize; // usually not tunable

    ck::index_t MPerBlock;
    ck::index_t NPerBlock;
    ck::index_t KPerBlock;

    ck::index_t M1PerThread;
    ck::index_t N1PerThread;
    ck::index_t KPerThread;

    ck::index_t M1N1ThreadClusterM10;
    ck::index_t M1N1ThreadClusterN10;
    ck::index_t M1N1ThreadClusterM11;
    ck::index_t M1N1ThreadClusterN11;

    std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K_M0_M1;
    std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K_M0_M1;
    std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
    ck::index_t ABlockTransferSrcVectorDim;
    ck::index_t ABlockTransferSrcScalarPerVector;
    ck::index_t ABlockTransferDstScalarPerVector_M1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K_N0_N1;
    std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K_N0_N1;
    std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
    ck::index_t BBlockTransferSrcVectorDim;
    ck::index_t BBlockTransferSrcScalarPerVector;
    ck::index_t BBlockTransferDstScalarPerVector_N1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 6> CThreadTransferSrcDstAccessOrder;
    ck::index_t CThreadTransferSrcDstVectorDim;
    ck::index_t CThreadTransferDstScalarPerVector;
};

static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw = {
    256, 128, 128, 8, 4, 4, 1,
    8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0},
    {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128},
    {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
    5, 1};
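
// Note (added for readability, not in the original header): the positional
// initializer above follows the member order of
// tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw, i.e. BlockSize = 256,
// MPerBlock = NPerBlock = 128, KPerBlock = 8, {M1,N1,K}PerThread = {4, 4, 1},
// M1N1ThreadCluster{M10,N10,M11,N11} = {8, 8, 2, 2}, followed by the A-, B-
// and C-side block/thread transfer parameters.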

struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
{
    ck::index_t BlockSize; // usually not tunable

    ck::index_t MPerBlock;
    ck::index_t NPerBlock;
    ck::index_t KPerBlock;

    ck::index_t MPerWave;
    ck::index_t NPerWave;
    ck::index_t K1;

    ck::index_t MRepeat;
    ck::index_t NRepeat;

    std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
    std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
    std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
    ck::index_t ABlockTransferSrcVectorDim;
    ck::index_t ABlockTransferSrcScalarPerVector;
    ck::index_t ABlockTransferDstScalarPerVector_K1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
    std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
    std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
    ck::index_t BBlockTransferSrcVectorDim;
    ck::index_t BBlockTransferSrcScalarPerVector;
    ck::index_t BBlockTransferDstScalarPerVector_K1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
    ck::index_t CThreadTransferSrcDstVectorDim;
    ck::index_t CThreadTransferDstScalarPerVector;
};

static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
    default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
        256, // BlockSize
        128, // MPerBlock,
        128, // NPerBlock,
        4, // KPerBlock,
        32, // MPerWave,
        32, // NPerWave,
        4, // K1,
        2, // MRepeat,
        2, // NRepeat,
        {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
        {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
        {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
        {1, 0, 2}, // ABlockTransferSrcAccessOrder,
        2, // ABlockTransferSrcVectorDim
        1, // ABlockTransferSrcScalarPerVector,
        4, // ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
        {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
        {0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder,
        {1, 0, 2}, // BBlockTransferSrcAccessOrder,
        1, // BBlockTransferSrcVectorDim
        1, // BBlockTransferSrcScalarPerVector
        4, // BBlockTransferDstScalarPerVector_K1
        false, // BThreadTransferSrcResetCoordinateAfterRun
        {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
        7, // CThreadTransferSrcDstVectorDim,
        1 // CThreadTransferDstScalarPerVector
};

struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
{
    ck::index_t BlockSize; // usually not tunable

    ck::index_t MPerBlock;
    ck::index_t NPerBlock;
    ck::index_t KPerBlock;

    ck::index_t MPerWave;
    ck::index_t NPerWave;
    ck::index_t K1;

    ck::index_t MRepeat;
    ck::index_t NRepeat;

    std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
    std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
    std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
    ck::index_t ABlockTransferSrcVectorDim;
    ck::index_t ABlockTransferSrcScalarPerVector;
    ck::index_t ABlockTransferDstScalarPerVector_K1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
    std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
    std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
    ck::index_t BBlockTransferSrcVectorDim;
    ck::index_t BBlockTransferSrcScalarPerVector;
    ck::index_t BBlockTransferDstScalarPerVector_K1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
    ck::index_t CThreadTransferSrcDstVectorDim;
    ck::index_t CThreadTransferDstScalarPerVector;
};

static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
    default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
        256, // BlockSize
        128, // MPerBlock,
        128, // NPerBlock,
        4, // KPerBlock,
        32, // MPerWave,
        32, // NPerWave,
        4, // K1,
        2, // MRepeat,
        2, // NRepeat,
        {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
        {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
        {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
        {1, 0, 2}, // ABlockTransferSrcAccessOrder,
        2, // ABlockTransferSrcVectorDim
        4, // ABlockTransferSrcScalarPerVector,
        4, // ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
        {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
        {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder,
        {1, 0, 2}, // BBlockTransferSrcAccessOrder,
        2, // BBlockTransferSrcVectorDim
        4, // BBlockTransferSrcScalarPerVector
        4, // BBlockTransferDstScalarPerVector_K1
        false, // BThreadTransferSrcResetCoordinateAfterRun
        {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
        7, // CThreadTransferSrcDstVectorDim,
        1 // CThreadTransferDstScalarPerVector
};

struct tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw
{
    ck::index_t BlockSize;

    ck::index_t GM1PerBlockGM11;
    ck::index_t GN1PerBlockGN11;
    ck::index_t KPerBlock;

    ck::index_t M1PerThread;
    ck::index_t N1PerThread;
    ck::index_t KPerThread;

    ck::index_t M1N1ThreadClusterM10;
    ck::index_t M1N1ThreadClusterN10;
    ck::index_t M1N1ThreadClusterM11;
    ck::index_t M1N1ThreadClusterN11;

    std::array<ck::index_t, 4> ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11;
    std::array<ck::index_t, 4> ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11;
    std::array<ck::index_t, 4> ABlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 4> ABlockTransferSrcAccessOrder;
    ck::index_t ABlockTransferSrcVectorDim;
    ck::index_t ABlockTransferSrcScalarPerVector;
    ck::index_t ABlockTransferDstScalarPerVector_GM11;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 4> BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11;
    std::array<ck::index_t, 4> BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11;
    std::array<ck::index_t, 4> BBlockTransferThreadClusterArrangeOrder;
    std::array<ck::index_t, 4> BBlockTransferSrcAccessOrder;
    ck::index_t BBlockTransferSrcVectorDim;
    ck::index_t BBlockTransferSrcScalarPerVector;
    ck::index_t BBlockTransferDstScalarPerVector_GN11;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    std::array<ck::index_t, 6> CThreadTransferSrcDstAccessOrder;
    ck::index_t CThreadTransferSrcDstVectorDim;
    ck::index_t CThreadTransferDstScalarPerVector;
};

static tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw = {
    256,                // BlockSize
    128,                // GM1PerBlockGM11
    32,                 // GN1PerBlockGN11
    8,                  // KPerBlock
    4,                  // M1PerThread
    4,                  // N1PerThread
    1,                  // KPerThread
    2,                  // M1N1ThreadClusterM10
    2,                  // M1N1ThreadClusterN10
    8,                  // M1N1ThreadClusterM11
    8,                  // M1N1ThreadClusterN11
    {4, 1, 1, 1},       // ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11
    {2, 1, 1, 128},     // ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11
    {3, 2, 1, 0},       // ABlockTransferThreadClusterArrangeOrder
    {3, 2, 1, 0},       // ABlockTransferSrcAccessOrder
    0,                  // ABlockTransferSrcVectorDim
    4,                  // ABlockTransferSrcScalarPerVector
    1,                  // ABlockTransferDstScalarPerVector_GM11
    false,              // AThreadTransferSrcResetCoordinateAfterRun
    {1, 4, 1, 1},       // BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11
    {8, 1, 1, 32},      // BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11
    {0, 3, 2, 1},       // BBlockTransferThreadClusterArrangeOrder
    {0, 3, 2, 1},       // BBlockTransferSrcAccessOrder
    3,                  // BBlockTransferSrcVectorDim
    1,                  // BBlockTransferSrcScalarPerVector
    1,                  // BBlockTransferDstScalarPerVector_GN11
    false,              // BThreadTransferSrcResetCoordinateAfterRun
    {3, 4, 5, 0, 1, 2}, // CThreadTransferSrcDstAccessOrder
    5,                  // CThreadTransferSrcDstVectorDim
    1};                 // CThreadTransferDstScalarPerVector

static inline int
conv_hw_out_size(int hw_in_size, int leftPad, int rightPad, int dilation, int yx_size, int stride)
{
    return (hw_in_size + leftPad + rightPad - dilation * (yx_size - 1) - 1) / stride + 1;
}
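
// Illustrative use of the formula above (added example, not part of the original
// header): a 224-wide input with symmetric padding 1, dilation 1, a 3-wide filter
// and stride 2 gives (224 + 1 + 1 - 1 * (3 - 1) - 1) / 2 + 1 = 112, i.e.
//   conv_hw_out_size(224, /*leftPad=*/1, /*rightPad=*/1, /*dilation=*/1, /*yx_size=*/3, /*stride=*/2) == 112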

#endif
@@ -1,520 +0,0 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1r2.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_hi_wi_c,
    const Tensor<TInWei>& wei_k_y_x_c,
    Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};
    constexpr auto I8 = Number<8>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);

#if 0
    // cdata = 16, BlockSize = 64, 16x64x4
    constexpr index_t BlockSize = 64;

    constexpr index_t GemmMPerBlockM1 = 16;
    constexpr index_t GemmNPerBlockN1 = 64;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmM1PerThreadM111 = 2;
    constexpr index_t GemmN1PerThreadN111 = 2;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
#elif 0
    // cdata = 32, BlockSize = 64, 16x128x4
    constexpr index_t BlockSize = 64;

    constexpr index_t GemmMPerBlockM1 = 16;
    constexpr index_t GemmNPerBlockN1 = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmM1PerThreadM111 = 2;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
#elif 0
    // cdata = 64, BlockSize = 64, 16x256x2
    constexpr index_t BlockSize = 64;

    constexpr index_t GemmMPerBlockM1 = 16;
    constexpr index_t GemmNPerBlockN1 = 256;
    constexpr index_t GemmKPerBlock = 2;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 1;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 16;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 16>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<2, 1, 4>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 0
    // cdata = 64, BlockSize = 64, 16x256x4
    constexpr index_t BlockSize = 64;

    constexpr index_t GemmMPerBlockM1 = 16;
    constexpr index_t GemmNPerBlockN1 = 256;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 1;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 16;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 4>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 0
    // cdata = 64, BlockSize = 128, 32x256x4
    constexpr index_t BlockSize = 128;

    constexpr index_t GemmMPerBlockM1 = 32;
    constexpr index_t GemmNPerBlockN1 = 256;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 16;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 0
    // cdata = 64, BlockSize = 128, 32x256x8
    constexpr index_t BlockSize = 128;

    constexpr index_t GemmMPerBlockM1 = 32;
    constexpr index_t GemmNPerBlockN1 = 256;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 16;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<2, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 2;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 2>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 1
    // cdata = 64, BlockSize = 256, 128x128x8
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlockM1 = 128;
    constexpr index_t GemmNPerBlockN1 = 128;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 1
    // cdata = 64, BlockSize = 256, 128x128x16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlockM1 = 128;
    constexpr index_t GemmNPerBlockN1 = 128;
    constexpr index_t GemmKPerBlock = 16;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

    using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 2>;
    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 64>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 2;

    using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#endif

#if 1
    const auto descs =
        transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
                                                                        in_n_hi_wi_c_desc,
                                                                        out_n_ho_wo_k_desc,
                                                                        conv_strides,
                                                                        conv_dilations,
                                                                        in_left_pads,
                                                                        in_right_pads);

#if 0
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));

    constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
#else
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

    constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
#endif

#else
    const auto descs =
        transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(wei_k_y_x_c_desc,
                                                                        in_n_hi_wi_c_desc,
                                                                        out_n_ho_wo_k_desc,
                                                                        conv_strides,
                                                                        conv_dilations,
                                                                        in_left_pads,
                                                                        in_right_pads);

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));

    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0>{};
#endif

    const auto wei_gemmk_gemmm_grid_desc = descs[I0];
    const auto in_gemmk_gemmn_grid_desc = descs[I1];
    const auto out_gemmm_gemmn_grid_desc = descs[I2];

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_dynamic_gemm_v1r2<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperation::Set,
            decltype(wei_gemmk_gemmm_grid_desc),
            decltype(in_gemmk_gemmn_grid_desc),
            decltype(out_gemmm_gemmn_grid_desc),
            GemmMPerBlockM1,
            GemmNPerBlockN1,
            GemmKPerBlock,
            GemmM1PerThreadM111,
            GemmN1PerThreadN111,
            GemmKPerThread,
            GemmM11N11ThreadClusterM1100,
            GemmM11N11ThreadClusterN1100,
            GemmM11N11ThreadClusterM1101,
            GemmM11N11ThreadClusterN1101,
            GemmABlockTransferThreadSliceLengths_K_M0_M1,
            GemmABlockTransferThreadClusterLengths_K_M0_M1,
            Sequence<1, 2, 0>, // ABlockTransferThreadClusterArrangeOrder
            Sequence<1, 2, 0>, // ABlockTransferSrcAccessOrder
            0, // ABlockTransferSrcVectorDim
            GemmABlockTransferSrcScalarPerVector_K,
            GemmABlockTransferDstScalarPerVector_M1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_K_N0_N1,
            GemmBBlockTransferThreadClusterLengths_K_N0_N1,
            Sequence<1, 2, 0>, // BBlockTransferThreadClusterArrangeOrder
            Sequence<1, 2, 0>, // BBlockTransferSrcAccessOrder
            0, // BBlockTransferSrcVectorDim
            GemmBBlockTransferSrcScalarPerVector_K,
            GemmBBlockTransferDstScalarPerVector_N1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
            2, // CThreadTransferSrcDstVectorDim
            GemmCThreadTransferDstScalarPerVector_M11,
            decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
            decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
            decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
            decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
            decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
            static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
            static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
            static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
            wei_gemmk_gemmm_grid_desc,
            in_gemmk_gemmn_grid_desc,
            out_gemmm_gemmn_grid_desc,
            wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
            in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
            out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
            wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
            in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
            nrepeat);

        {
            const auto N = out_n_ho_wo_k_lengths[I0];
            const auto K = out_n_ho_wo_k_lengths[I3];
            const auto C = wei_k_y_x_c_lengths[I3];

            const auto Hi = in_n_hi_wi_c_lengths[I1];
            const auto Wi = in_n_hi_wi_c_lengths[I2];

            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];

            const auto Y = wei_k_y_x_c_lengths[I1];
            const auto X = wei_k_y_x_c_lengths[I2];

            float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }
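
    // Worked example of the formula above (illustrative numbers, not from the
    // original source): N=128, K=256, Ho=Wo=28, C=192, Y=X=3 gives
    // 2*N*K*Ho*Wo*C*Y*X ~= 8.88e10 FLOP; with ave_time = 10 ms the printed value
    // is 8.88e10 / 1e9 / 10 ~= 8.9, i.e. GFLOP per millisecond, which is the
    // same as TFlop/s.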

    // copy result back to host
    out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}
@@ -1,240 +0,0 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_v1r1.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);

#if 1
    // cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 4, 32] = [1, 128, 4, 32]
    constexpr index_t BlockSize = 256;

    constexpr index_t N0 = 4;

    constexpr index_t GemmGM1PerBlockGM11 = 128;
    constexpr index_t GemmGN1PerBlockGN11 = 32;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

    using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;

    using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 1, 1, 32>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#elif 1
    // cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 8, 16] = [1, 128, 8, 16]
    constexpr index_t BlockSize = 256;

    constexpr index_t N0 = 8;

    constexpr index_t GemmGM1PerBlockGM11 = 128;
    constexpr index_t GemmGN1PerBlockGN11 = 16;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

    using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;

    using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 2, 1, 16>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#endif

    const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads);

    const auto wei_gk_gm0_gm1_grid_desc = descs[I0];
    const auto in_gk_gn0_gn1_grid_desc = descs[I1];
    const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gk_gm0_gm10_gm11_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0>{}));

    constexpr auto in_gk_gn0_gn10_gn11_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));

    constexpr auto out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}));

    constexpr auto wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0>{};

    constexpr auto in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_dynamic_contraction_v1r1<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperation::Set,
            decltype(wei_gk_gm0_gm1_grid_desc),
            decltype(in_gk_gn0_gn1_grid_desc),
            decltype(out_gm0_gm1_gn0_gn1_grid_desc),
            GemmGM1PerBlockGM11,
            GemmGN1PerBlockGN11,
            GemmKPerBlock,
            GemmM1PerThreadM111,
            GemmN1PerThreadN111,
            GemmKPerThread,
            GemmM11N11ThreadClusterM1100,
            GemmM11N11ThreadClusterN1100,
            GemmM11N11ThreadClusterM1101,
            GemmM11N11ThreadClusterN1101,
            GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
            GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
            Sequence<3, 2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
            Sequence<3, 2, 1, 0>, // ABlockTransferSrcAccessOrder
            0, // ABlockTransferSrcVectorDim
            GemmABlockTransferSrcScalarPerVector_GK,
            GemmABlockTransferDstScalarPerVector_GM11,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
            GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
            Sequence<0, 3, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
            Sequence<0, 3, 2, 1>, // BBlockTransferSrcAccessOrder
            3, // BBlockTransferSrcVectorDim
            GemmBBlockTransferSrcScalarPerVector_GN11,
            GemmBBlockTransferDstScalarPerVector_GN11,
            false, // don't move back src coordinate after threadwise copy
            Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
            5, // CThreadTransferSrcDstVectorDim
            GemmCThreadTransferDstScalarPerVector_BN1,
            decltype(wei_gk_gm0_gm10_gm11_grid_iterator_hacks),
            decltype(in_gk_gn0_gn10_gn11_grid_iterator_hacks),
            decltype(out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks),
            decltype(wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks),
            decltype(in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks)>(
            static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
            static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
            wei_gk_gm0_gm1_grid_desc,
            in_gk_gn0_gn1_grid_desc,
            out_gm0_gm1_gn0_gn1_grid_desc,
            wei_gk_gm0_gm10_gm11_grid_iterator_hacks,
            in_gk_gn0_gn10_gn11_grid_iterator_hacks,
            out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks,
            wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks,
            in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks,
            nrepeat);

        float perf = (float)calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
@@ -1,404 +0,0 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"

#include "olc_driver_common.hpp"
#include "conv_tunables.hpp"

#include "handle.hpp"

namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw {

template <typename TInWei, typename TAcc, typename TOut>
static std::string get_network_config_string_from_types()
{
    std::string out;

    // append one character per data type (in/wei, accumulator, out)
    out += static_cast<char>(Driver::get_typeid_from_type<TInWei>());
    out += static_cast<char>(Driver::get_typeid_from_type<TAcc>());
    out += static_cast<char>(Driver::get_typeid_from_type<TOut>());

    return (out);
};

static std::string
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* pt)
{
    std::string out("TUN_");

    out += std::to_string(pt->BlockSize) + "_";

    out += std::to_string(pt->GM1PerBlockGM11) + "x" + std::to_string(pt->GN1PerBlockGN11) + "x" +
           std::to_string(pt->KPerBlock) + "_";
    out += std::to_string(pt->M1PerThread) + "x" + std::to_string(pt->N1PerThread) + "x" +
           std::to_string(pt->KPerThread) + "_";
    out += std::to_string(pt->M1N1ThreadClusterM10) + "x" +
           std::to_string(pt->M1N1ThreadClusterN10) + "x" +
           std::to_string(pt->M1N1ThreadClusterM11) + "x" +
           std::to_string(pt->M1N1ThreadClusterN11) + "_";

    out += std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[2]) + "x" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[3]) + "_";

    out += std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[2]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[3]) + "_";

    out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[3]) + "_";

    out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "x" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[3]) + "_";

    out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
    out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
    out += std::to_string(pt->ABlockTransferDstScalarPerVector_GM11) + "_";
    out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";

    out += std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[2]) + "x" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[3]);

    out += std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[2]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[3]) + "_";

    out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[3]) + "_";

    out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "x" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[3]) + "_";

    out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
    out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
    out += std::to_string(pt->BBlockTransferDstScalarPerVector_GN11) + "_";
    out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";

    out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "_";

    out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
    out += std::to_string(pt->CThreadTransferDstScalarPerVector);

    return (out);
};
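
// For reference (added note, not in the original source): applied to
// default_tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw this produces a key that
// starts with "TUN_256_128x32x8_4x4x1_2x2x8x8_4x1x1x1_...", i.e. block size,
// then the per-block, per-thread and thread-cluster tiles, then the transfer
// parameters, joined with '_' and 'x' separators.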
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_definition_string_from_types()
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TInWei>()) +
|
||||
" -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
|
||||
" -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* pt)
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
|
||||
|
||||
out += " -DCK_PARAM_GM1PerBlockGM11=" + std::to_string(pt->GM1PerBlockGM11) +
|
||||
" -DCK_PARAM_GN1PerBlockGN11=" + std::to_string(pt->GN1PerBlockGN11) +
|
||||
" -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
|
||||
out += " -DCK_PARAM_M1PerThread=" + std::to_string(pt->M1PerThread) +
|
||||
" -DCK_PARAM_N1PerThread=" + std::to_string(pt->N1PerThread) +
|
||||
" -DCK_PARAM_KPerThread=" + std::to_string(pt->KPerThread);
|
||||
|
||||
out += " -DCK_PARAM_M1N1ThreadClusterM10=" + std::to_string(pt->M1N1ThreadClusterM10) +
|
||||
" -DCK_PARAM_M1N1ThreadClusterN10=" + std::to_string(pt->M1N1ThreadClusterN10) +
|
||||
" -DCK_PARAM_M1N1ThreadClusterM11=" + std::to_string(pt->M1N1ThreadClusterM11) +
|
||||
" -DCK_PARAM_M1N1ThreadClusterN11=" + std::to_string(pt->M1N1ThreadClusterN11);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11=" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[2]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[3]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11=" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[2]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[3]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[3]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "," +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[3]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
|
||||
out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
|
||||
std::to_string(pt->ABlockTransferSrcScalarPerVector);
|
||||
out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_GM11=" +
|
||||
std::to_string(pt->ABlockTransferDstScalarPerVector_GM11);
|
||||
out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
|
||||
std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11=" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[2]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[3]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11=" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[2]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[3]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[3]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "," +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[3]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
|
||||
out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
|
||||
std::to_string(pt->BBlockTransferSrcScalarPerVector);
|
||||
out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_GN11=" +
|
||||
std::to_string(pt->BBlockTransferDstScalarPerVector_GN11);
|
||||
out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
|
||||
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]);
out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
|
||||
std::to_string(pt->CThreadTransferSrcDstVectorDim);
|
||||
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
||||
std::to_string(pt->CThreadTransferDstScalarPerVector);
return (out);
|
||||
};
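// Illustration only (the values below are hypothetical, not taken from a real tunable):
// the string assembled above concatenates one -DCK_PARAM_* define per tunable field, e.g.
//   " -DCK_PARAM_ABlockTransferSrcVectorDim=0"
//   " -DCK_PARAM_ABlockTransferSrcScalarPerVector=4"
//   " -DCK_PARAM_CThreadTransferDstScalarPerVector=1"
// and is later appended to the compilation options passed to Handle::AddKernel below.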
} // namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_olc(
|
||||
olCompile::Handle* handle,
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* tunable,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw;
|
||||
using size_t = std::size_t;
constexpr index_t N0 = 4; // N0 cannot be made a tunable parameter so far
////////////////////////////////////////////////////////////////////////////////////////////////////////////
// The following code is only used for computing grid_size, hasMainKBlockLoop and
// hasDoubleTailKBlockLoop
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
const auto a_gk_gm0_gm1_grid_desc = descs[I0];
|
||||
const auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
|
||||
|
||||
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
|
||||
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
|
||||
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
|
||||
|
||||
const index_t grid_size = (GM1 / tunable->GM1PerBlockGM11) * (GN1 / tunable->GN1PerBlockGN11);
|
||||
const bool hasMainKBlockLoop = ((GK + tunable->KPerBlock) / (2 * tunable->KPerBlock) > 1);
|
||||
const bool hasDoubleTailKBlockLoop = ((GK / tunable->KPerBlock) % 2 == 0);
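// Added note: the division by 2 * KPerBlock above reflects that the GEMM K loop consumes
// two KPerBlock slices per iteration (double buffering). Worked example with hypothetical
// sizes: GK = 64, KPerBlock = 8 gives (64 + 8) / 16 = 4 > 1, so hasMainKBlockLoop = true,
// and (64 / 8) % 2 == 0, so hasDoubleTailKBlockLoop = true.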
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
// these are workspace buffers that should be exposed to the user through the corresponding
// workspace API
|
||||
DeviceMem workspace_buf(4096);
|
||||
|
||||
void* a_gk_gm0_gm10_gm11_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
|
||||
void* b_gk_gn0_gn10_gn11_grid_desc_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
|
||||
void* c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
|
||||
void* c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
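// The 4096-byte workspace is split into four 1024-byte slots (offsets 0, 1024, 2048, 3072):
// the "_prepare" kernel below serializes the grid descriptors and the block-cluster adaptor
// into these slots, and the main kernel then reads them back as const void* arguments.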
const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
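// vld / vgd appear to follow the OpenCL-style convention of the online-compilation Handle:
// vld is the work-group (block) size and vgd is the total work-item count, so vgd2 launches
// grid_size blocks of BlockSize threads each for the main kernel.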
std::string program_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp";
|
||||
std::string algo_name = "implicit_gemm_conv_fwd_v4r4_nchw";
|
||||
|
||||
std::string param = " -std=c++17 ";
|
||||
std::string network_config;
|
||||
|
||||
param += get_definition_string_from_types<TInWei, TAcc, TOut>() +
|
||||
" -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) +
|
||||
" -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop) +
|
||||
" -DCK_PARAM_N0=" + std::to_string(N0) + " " +
|
||||
get_definition_string_from_tunable(tunable);
|
||||
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_V" +
|
||||
std::to_string(hasDoubleTailKBlockLoop) + "_" + std::to_string(N0) + "_" +
|
||||
get_network_config_string_from_tunable(tunable);
|
||||
|
||||
std::vector<float> kernel1_times;
|
||||
std::vector<float> kernel2_times;
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
KernelTimer timer1, timer2;
|
||||
std::string kernel_name;
|
||||
|
||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_prepare";
|
||||
auto network_config_1 = network_config + "_1";
|
||||
|
||||
timer1.Start();
|
||||
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
|
||||
conv_strides[I0],
|
||||
conv_strides[I1],
|
||||
conv_dilations[I0],
|
||||
conv_dilations[I1],
|
||||
in_left_pads[I0],
|
||||
in_left_pads[I1],
|
||||
in_right_pads[I0],
|
||||
in_right_pads[I1],
|
||||
a_gk_gm0_gm10_gm11_grid_desc_dev_buf,
|
||||
b_gk_gn0_gn10_gn11_grid_desc_dev_buf,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf);
|
||||
timer1.End();
|
||||
|
||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw";
|
||||
auto network_config_2 = network_config + "_2";
|
||||
|
||||
timer2.Start();
|
||||
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
|
||||
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
|
||||
(const void*)(a_gk_gm0_gm10_gm11_grid_desc_dev_buf),
|
||||
(const void*)(b_gk_gn0_gn10_gn11_grid_desc_dev_buf),
|
||||
(const void*)(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf),
|
||||
(const void*)(c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf));
|
||||
timer2.End();
|
||||
|
||||
kernel1_times.push_back(timer1.GetElapsedTime());
|
||||
kernel2_times.push_back(timer2.GetElapsedTime());
|
||||
}
|
||||
|
||||
{
|
||||
auto ave_time1 = Driver::get_effective_average(kernel1_times);
|
||||
auto ave_time2 = Driver::get_effective_average(kernel2_times);
|
||||
|
||||
const auto N = in_n_c_hi_wi_lengths[I0];
|
||||
const auto C = in_n_c_hi_wi_lengths[I1];
|
||||
|
||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
||||
|
||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
||||
const auto X = wei_k_c_y_x_lengths[I3];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
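// ave_time1 / ave_time2 are in milliseconds, so FLOP / 1e9 / ms yields TFlop/s; the factor
// of 2 counts one multiply and one add per output MAC.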
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
||||
};
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
host/CMakeLists.txt (new file)
@@ -0,0 +1,4 @@
|
||||
add_subdirectory(host_tensor)
|
||||
add_subdirectory(online_compilation)
|
||||
add_subdirectory(driver_offline)
|
||||
add_subdirectory(driver_online)
|
||||
host/driver_offline/CMakeLists.txt (new file)
@@ -0,0 +1,21 @@
|
||||
include_directories(BEFORE
|
||||
include
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
${PROJECT_SOURCE_DIR}/external/half/include
|
||||
)
|
||||
|
||||
set(CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp)
|
||||
set(CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp)
|
||||
|
||||
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
|
||||
|
||||
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
|
||||
@@ -13,34 +13,28 @@
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4_NHWC 0
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 0
|
||||
#define USE_CONV_FWD_V4R5_NCHW 0
|
||||
#define USE_CONV_FWD_V4R5R2_NCHW 0
|
||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW, // 0
|
||||
V4R4NHWC, // 1
|
||||
V4R4R2NHWC, // 2
|
||||
V4R5NCHW, // 3
|
||||
V4R5R2NCHW, // 4
|
||||
V5R1NCHW, // 5
|
||||
V4R4R2XDLNCHW, // 6
|
||||
V4R4R4XDLNHWC // 7
|
||||
V4R4R2NHWC, // 1
|
||||
V6R1NCHW, // 2
|
||||
V5R1NCHW, // 3
|
||||
V4R4R2XDLNCHW, // 4
|
||||
V4R4R4XDLNHWC // 5
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -132,7 +126,7 @@ int main(int argc, char* argv[])
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
@@ -323,32 +317,6 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4NHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4R2_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4R2NHWC)
|
||||
{
|
||||
@@ -376,8 +344,8 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R5_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R5NCHW)
|
||||
#if USE_CONV_FWD_V6R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V6R1NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
@@ -386,7 +354,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw<in_data_t,
|
||||
device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
@@ -402,33 +370,6 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R5R2_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R5R2NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V5R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V5R1NCHW)
|
||||
{
|
||||
@@ -0,0 +1,283 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
#if 0
|
||||
float ave_time = launch_kernel_dynamic_gemm_xdlops_v1
|
||||
#else
|
||||
float ave_time = launch_kernel_dynamic_gemm_xdlops_v2
|
||||
#endif
|
||||
<BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmKPack,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<1, 0, 2>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused
|
||||
// with MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,240 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_gemm_xdlops_v2r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1>,
|
||||
2,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_dynamic_contraction_v1r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
@@ -43,11 +43,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
const auto in_desc_n_c_hi_wi =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
const auto wei_desc_k_c_y_x =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
const auto out_desc_n_k_ho_wo =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
@@ -58,32 +58,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 1;
|
||||
|
||||
constexpr index_t GemmGM1PerBlockGM11 = 128;
|
||||
constexpr index_t GemmGN1PerBlockGN11 = 32;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
constexpr index_t BM10BN10ThreadClusterBM100 = 8;
|
||||
constexpr index_t BM10BN10ThreadClusterBN100 = 8;
|
||||
constexpr index_t BM10BN10ThreadClusterBM101 = 2;
|
||||
constexpr index_t BM10BN10ThreadClusterBN101 = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#elif 1
|
||||
// [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
|
||||
// cdata = 64, BlockSize = 256
|
||||
@@ -92,48 +92,48 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 2;
|
||||
|
||||
constexpr index_t GemmGM1PerBlockGM11 = 128;
|
||||
constexpr index_t GemmGN1PerBlockGN11 = 32;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
constexpr index_t BM10BN10ThreadClusterBM100 = 8;
|
||||
constexpr index_t BM10BN10ThreadClusterBN100 = 8;
|
||||
constexpr index_t BM10BN10ThreadClusterBM101 = 2;
|
||||
constexpr index_t BM10BN10ThreadClusterBN101 = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GN0>{},
|
||||
Number<GK1>{});
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x,
|
||||
in_desc_n_c_hi_wi,
|
||||
out_desc_n_k_ho_wo,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GN0>{},
|
||||
Number<GK1>{});
|
||||
|
||||
const auto wei_gk0_gm0_gm1_gk1_grid_desc = descs[I0];
|
||||
const auto in_gk0_gn0_gn1_gk1_grid_desc = descs[I1];
|
||||
const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
|
||||
const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over the A, B and C matrices
|
||||
constexpr auto wei_grid_iterator_hacks =
|
||||
@@ -189,36 +189,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_gk0_gm0_gm1_gk1_grid_desc),
|
||||
decltype(in_gk0_gn0_gn1_gk1_grid_desc),
|
||||
decltype(out_gm0_gm1_gn0_gn1_grid_desc),
|
||||
GemmGM1PerBlockGM11,
|
||||
GemmGN1PerBlockGN11,
|
||||
GemmKPerBlock,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmM11N11ThreadClusterM1100,
|
||||
GemmM11N11ThreadClusterN1100,
|
||||
GemmM11N11ThreadClusterM1101,
|
||||
GemmM11N11ThreadClusterN1101,
|
||||
GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
|
||||
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
|
||||
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM100,
|
||||
BM10BN10ThreadClusterBN100,
|
||||
BM10BN10ThreadClusterBM101,
|
||||
BM10BN10ThreadClusterBN101,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
|
||||
GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
|
||||
GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_BN1,
|
||||
CThreadTransferDstScalarPerVector_BN1,
|
||||
decltype(wei_grid_iterator_hacks),
|
||||
decltype(in_grid_iterator_hacks),
|
||||
decltype(out_grid_iterator_hacks),
|
||||
@@ -227,9 +227,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gk0_gm0_gm1_gk1_grid_desc,
|
||||
in_gk0_gn0_gn1_gk1_grid_desc,
|
||||
out_gm0_gm1_gn0_gn1_grid_desc,
|
||||
wei_grid_desc_gk0_gm0_gm1_gk1,
|
||||
in_grid_desc_gk0_gn0_gn1_gk1,
|
||||
out_grid_desc_gm0_gm1_gn0_gn1,
|
||||
wei_grid_iterator_hacks,
|
||||
in_grid_iterator_hacks,
|
||||
out_grid_iterator_hacks,
|
||||
@@ -238,7 +238,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
21
host/driver_online/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
include_directories(BEFORE
include
${PROJECT_BINARY_DIR}/host/online_compilation/include
${PROJECT_SOURCE_DIR}/host/online_compilation/include
${PROJECT_SOURCE_DIR}/host/host_tensor/include
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/rocm/include
${PROJECT_SOURCE_DIR}/external/half/include
)

set(CONV_FWD_DRIVER_ONLINE_SOURCE conv_fwd_driver_online.cpp)

add_executable(conv_fwd_driver_online ${CONV_FWD_DRIVER_ONLINE_SOURCE})

target_link_libraries(conv_fwd_driver_online PRIVATE host_tensor)
target_link_libraries(conv_fwd_driver_online PRIVATE online_compilation)
@@ -12,26 +12,22 @@
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R5_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
|
||||
|
||||
#include "conv_tunables.hpp"
|
||||
#include "handle.hpp"
|
||||
#include "hipCheck.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V6R1_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW, // 0
|
||||
V4R5NCHW, // 1
|
||||
V6R1NCHW, // 1
|
||||
V4R4XDLNCHW, // 2
|
||||
V4R4XDLNHWC // 3
|
||||
};
|
||||
@@ -94,15 +90,17 @@ int main(int argc, char* argv[])
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
|
||||
#if 1
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
constexpr index_t in_vector_size = 16;
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
@@ -230,9 +228,9 @@ int main(int argc, char* argv[])
|
||||
tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* tunable =
|
||||
&default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw;
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_olc<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
@@ -249,8 +247,8 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R5_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R5NCHW)
|
||||
#if USE_CONV_FWD_V6R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V6R1NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
@@ -259,12 +257,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* tunable =
|
||||
&default_tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw;
|
||||
const auto tunable = tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw{};
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_olc<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
@@ -294,22 +291,22 @@ int main(int argc, char* argv[])
|
||||
tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable =
|
||||
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
tunable,
|
||||
nrepeat);
|
||||
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw<
|
||||
in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
tunable,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -326,22 +323,22 @@ int main(int argc, char* argv[])
|
||||
tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable =
|
||||
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
tunable,
|
||||
nrepeat);
|
||||
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk<
|
||||
in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
tunable,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
struct tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw
|
||||
{
|
||||
int32_t BlockSize;
|
||||
|
||||
int32_t MPerBlock;
|
||||
int32_t NPerBlock;
|
||||
int32_t KPerBlock;
|
||||
|
||||
int32_t M1PerThread;
|
||||
int32_t N1PerThread;
|
||||
int32_t KPerThread;
|
||||
|
||||
int32_t M1N1ThreadClusterM10;
|
||||
int32_t M1N1ThreadClusterN10;
|
||||
int32_t M1N1ThreadClusterM11;
|
||||
int32_t M1N1ThreadClusterN11;
|
||||
|
||||
std::array<int32_t, 3> ABlockTransferThreadSliceLengths_K_M0_M1;
|
||||
std::array<int32_t, 3> ABlockTransferThreadClusterLengths_K_M0_M1;
|
||||
std::array<int32_t, 3> ABlockTransferThreadClusterArrangeOrder;
|
||||
std::array<int32_t, 3> ABlockTransferSrcAccessOrder;
|
||||
int32_t ABlockTransferSrcVectorDim;
|
||||
int32_t ABlockTransferSrcScalarPerVector;
|
||||
int32_t ABlockTransferDstScalarPerVector_M1;
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<int32_t, 3> BBlockTransferThreadSliceLengths_K_N0_N1;
|
||||
std::array<int32_t, 3> BBlockTransferThreadClusterLengths_K_N0_N1;
|
||||
std::array<int32_t, 3> BBlockTransferThreadClusterArrangeOrder;
|
||||
std::array<int32_t, 3> BBlockTransferSrcAccessOrder;
|
||||
int32_t BBlockTransferSrcVectorDim;
|
||||
int32_t BBlockTransferSrcScalarPerVector;
|
||||
int32_t BBlockTransferDstScalarPerVector_N1;
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<int32_t, 6> CThreadTransferSrcDstAccessOrder;
|
||||
int32_t CThreadTransferSrcDstVectorDim;
|
||||
int32_t CThreadTransferDstScalarPerVector;
|
||||
};
|
||||
|
||||
static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw = {
|
||||
256, 128, 128, 8, 4, 4, 1,
|
||||
8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0},
|
||||
{2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128},
|
||||
{0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
|
||||
5, 1};
|
||||
#endif
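For readability, the positional aggregate initializer above maps onto the struct members in declaration order. An equivalently-valued, annotated copy is sketched below (illustrative only; the name annotated_default is hypothetical and nothing in the build uses it):

static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw annotated_default = {
    256,                // BlockSize
    128, 128, 8,        // MPerBlock, NPerBlock, KPerBlock
    4, 4, 1,            // M1PerThread, N1PerThread, KPerThread
    8, 8, 2, 2,         // M1N1ThreadClusterM10, N10, M11, N11
    {4, 1, 1},          // ABlockTransferThreadSliceLengths_K_M0_M1
    {2, 1, 128},        // ABlockTransferThreadClusterLengths_K_M0_M1
    {2, 1, 0},          // ABlockTransferThreadClusterArrangeOrder
    {2, 1, 0},          // ABlockTransferSrcAccessOrder
    0, 4, 1, false,     // ABlockTransferSrcVectorDim, SrcScalarPerVector,
                        // DstScalarPerVector_M1, AThreadTransferSrcResetCoordinateAfterRun
    {4, 1, 1},          // BBlockTransferThreadSliceLengths_K_N0_N1
    {2, 1, 128},        // BBlockTransferThreadClusterLengths_K_N0_N1
    {0, 1, 2},          // BBlockTransferThreadClusterArrangeOrder
    {0, 1, 2},          // BBlockTransferSrcAccessOrder
    2, 1, 1, false,     // BBlockTransferSrcVectorDim, SrcScalarPerVector,
                        // DstScalarPerVector_N1, BThreadTransferSrcResetCoordinateAfterRun
    {3, 4, 5, 0, 1, 2}, // CThreadTransferSrcDstAccessOrder
    5, 1};              // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector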
|
||||
@@ -0,0 +1,73 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
||||
{
|
||||
int32_t BlockSize;
|
||||
|
||||
int32_t MPerBlock;
|
||||
int32_t NPerBlock;
|
||||
int32_t KPerBlock;
|
||||
|
||||
int32_t MPerWave;
|
||||
int32_t NPerWave;
|
||||
int32_t K1;
|
||||
|
||||
int32_t MRepeat;
|
||||
int32_t NRepeat;
|
||||
|
||||
std::array<int32_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
|
||||
std::array<int32_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
|
||||
std::array<int32_t, 3> ABlockTransferThreadClusterArrangeOrder;
|
||||
std::array<int32_t, 3> ABlockTransferSrcAccessOrder;
|
||||
int32_t ABlockTransferSrcVectorDim;
|
||||
int32_t ABlockTransferSrcScalarPerVector;
|
||||
int32_t ABlockTransferDstScalarPerVector_K1;
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<int32_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
|
||||
std::array<int32_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
|
||||
std::array<int32_t, 3> BBlockTransferThreadClusterArrangeOrder;
|
||||
std::array<int32_t, 3> BBlockTransferSrcAccessOrder;
|
||||
int32_t BBlockTransferSrcVectorDim;
|
||||
int32_t BBlockTransferSrcScalarPerVector;
|
||||
int32_t BBlockTransferDstScalarPerVector_K1;
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<int32_t, 8> CThreadTransferSrcDstAccessOrder;
|
||||
int32_t CThreadTransferSrcDstVectorDim;
|
||||
int32_t CThreadTransferDstScalarPerVector;
|
||||
};
|
||||
|
||||
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
||||
default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
|
||||
256, // BlockSize
|
||||
128, // MPerBlock,
|
||||
128, // NPerBlock,
|
||||
4, // KPerBlock,
|
||||
32, // MPerWave,
|
||||
32, // NPerWave,
|
||||
4, // K1,
|
||||
2, // MRepeat,
|
||||
2, // NRepeat,
|
||||
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector,
|
||||
4, // ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
{0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun
|
||||
{3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
1 // CThreadTransferDstScalarPerVector
|
||||
};
|
||||
#endif
|
||||
@@ -0,0 +1,73 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
|
||||
{
|
||||
int32_t BlockSize;
|
||||
|
||||
int32_t MPerBlock;
|
||||
int32_t NPerBlock;
|
||||
int32_t KPerBlock;
|
||||
|
||||
int32_t MPerWave;
|
||||
int32_t NPerWave;
|
||||
int32_t K1;
|
||||
|
||||
int32_t MRepeat;
|
||||
int32_t NRepeat;
|
||||
|
||||
std::array<int32_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
|
||||
std::array<int32_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
|
||||
std::array<int32_t, 3> ABlockTransferThreadClusterArrangeOrder;
|
||||
std::array<int32_t, 3> ABlockTransferSrcAccessOrder;
|
||||
int32_t ABlockTransferSrcVectorDim;
|
||||
int32_t ABlockTransferSrcScalarPerVector;
|
||||
int32_t ABlockTransferDstScalarPerVector_K1;
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<int32_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
|
||||
std::array<int32_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
|
||||
std::array<int32_t, 3> BBlockTransferThreadClusterArrangeOrder;
|
||||
std::array<int32_t, 3> BBlockTransferSrcAccessOrder;
|
||||
int32_t BBlockTransferSrcVectorDim;
|
||||
int32_t BBlockTransferSrcScalarPerVector;
|
||||
int32_t BBlockTransferDstScalarPerVector_K1;
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<int32_t, 8> CThreadTransferSrcDstAccessOrder;
|
||||
int32_t CThreadTransferSrcDstVectorDim;
|
||||
int32_t CThreadTransferDstScalarPerVector;
|
||||
};
|
||||
|
||||
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
|
||||
default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
|
||||
256, // BlockSize
|
||||
128, // MPerBlock,
|
||||
128, // NPerBlock,
|
||||
4, // KPerBlock,
|
||||
32, // MPerWave,
|
||||
32, // NPerWave,
|
||||
4, // K1,
|
||||
2, // MRepeat,
|
||||
2, // NRepeat,
|
||||
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
4, // ABlockTransferSrcScalarPerVector,
|
||||
4, // ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
{1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
4, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun
|
||||
{2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
1 // CThreadTransferDstScalarPerVector
|
||||
};
|
||||
#endif
|
||||
@@ -0,0 +1,42 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
struct tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw
|
||||
{
|
||||
int32_t BlockSize = 256;
|
||||
|
||||
int32_t GN0 = 4;
|
||||
int32_t GK1 = 1;
|
||||
|
||||
int32_t GM1PerBlockGM11 = 128;
|
||||
int32_t GN1PerBlockGN11 = 32;
|
||||
int32_t GK0PerBlock = 8;
|
||||
|
||||
int32_t BM1PerThreadBM11 = 4;
|
||||
int32_t BN1PerThreadBN11 = 4;
|
||||
int32_t BK0PerThread = 1;
|
||||
|
||||
int32_t BM10BN10ThreadClusterBM100 = 2;
|
||||
int32_t BM10BN10ThreadClusterBN100 = 2;
|
||||
int32_t BM10BN10ThreadClusterBM101 = 8;
|
||||
int32_t BM10BN10ThreadClusterBN101 = 8;
|
||||
|
||||
std::array<int32_t, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = {4, 1, 1, 1, 1};
|
||||
std::array<int32_t, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
2, 1, 1, 128, 1};
|
||||
std::array<int32_t, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
4, 1, 1, 1, 1};
|
||||
std::array<int32_t, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
1, 1, 1, 1, 1};
|
||||
|
||||
std::array<int32_t, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = {1, 4, 1, 1, 1};
|
||||
std::array<int32_t, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
8, 1, 1, 32, 1};
|
||||
std::array<int32_t, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
1, 1, 1, 1, 1};
|
||||
std::array<int32_t, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
1, 1, 1, 1, 1};
|
||||
|
||||
int32_t CThreadTransferDstScalarPerVector = 1;
|
||||
};
|
||||
#endif
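Because every member above carries a default initializer, a value-initialized object is already a complete configuration, which is how the online driver earlier in this change constructs it. A minimal usage sketch (the GK0PerBlock override is a hypothetical tuning tweak, not something this commit does):

    auto tunable = tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw{};
    tunable.GK0PerBlock = 16; // hypothetical: try a deeper GK0 tile per block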
|
||||
@@ -1,13 +1,11 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "handle.hpp"
|
||||
#include "online_driver_common.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
#include "olc_driver_common.hpp"
|
||||
#include "conv_tunables.hpp"
|
||||
|
||||
#include "handle.hpp"
|
||||
#include "conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw {
|
||||
|
||||
@@ -211,7 +209,7 @@ template <typename TInWei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_olc(
|
||||
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
olCompile::Handle* handle,
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
@@ -1,12 +1,10 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "handle.hpp"
|
||||
#include "online_driver_common.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "olc_driver_common.hpp"
|
||||
#include "conv_tunables.hpp"
|
||||
|
||||
#include "handle.hpp"
|
||||
#include "conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw {
|
||||
|
||||
@@ -208,7 +206,7 @@ template <typename TInWei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc(
|
||||
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
olCompile::Handle* handle,
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
@@ -1,13 +1,11 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "handle.hpp"
|
||||
#include "online_driver_common.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#include "olc_driver_common.hpp"
|
||||
#include "conv_tunables.hpp"
|
||||
|
||||
#include "handle.hpp"
|
||||
#include "conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk {
|
||||
|
||||
@@ -209,7 +207,7 @@ template <typename TInWei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc(
|
||||
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
olCompile::Handle* handle,
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
@@ -0,0 +1,425 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "handle.hpp"
|
||||
#include "online_driver_common.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw {
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_network_config_string_from_types()
|
||||
{
|
||||
std::string out("DAT_");
|
||||
|
||||
out += static_cast<char>(Driver::get_typeid_from_type<TInWei>()) +
|
||||
static_cast<char>(Driver::get_typeid_from_type<TAcc>()) +
|
||||
static_cast<char>(Driver::get_typeid_from_type<TOut>());
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable)
|
||||
{
|
||||
std::string out("TUN_");
|
||||
|
||||
out += std::to_string(tunable.BlockSize) + "_";
|
||||
|
||||
out += std::to_string(tunable.GN0) + "x" + std::to_string(tunable.GK1) + "_";
|
||||
|
||||
out += std::to_string(tunable.GM1PerBlockGM11) + "x" + std::to_string(tunable.GN1PerBlockGN11) +
|
||||
"x" + std::to_string(tunable.GK0PerBlock) + "_";
|
||||
|
||||
out += std::to_string(tunable.BM1PerThreadBM11) + "x" +
|
||||
std::to_string(tunable.BN1PerThreadBN11) + "x" + std::to_string(tunable.BK0PerThread) +
|
||||
"_";
|
||||
|
||||
out += std::to_string(tunable.BM10BN10ThreadClusterBM100) + "x" +
|
||||
std::to_string(tunable.BM10BN10ThreadClusterBN100) + "x" +
|
||||
std::to_string(tunable.BM10BN10ThreadClusterBM101) + "x" +
|
||||
std::to_string(tunable.BM10BN10ThreadClusterBN101) + "_";
|
||||
|
||||
out += std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_";
|
||||
|
||||
out +=
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_";
|
||||
|
||||
out += std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) +
|
||||
"_";
|
||||
|
||||
out += std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
|
||||
"x" +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) +
|
||||
"_";
|
||||
|
||||
out += std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_";
|
||||
|
||||
out +=
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_";
|
||||
|
||||
out += std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) +
|
||||
"_";
|
||||
|
||||
out += std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
|
||||
"x" +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) +
|
||||
"_";
|
||||
|
||||
out += std::to_string(tunable.CThreadTransferDstScalarPerVector);
|
||||
|
||||
return (out);
|
||||
};
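// Hedged illustration: the string below is assembled by hand from the tunable's default member
// initializers, not captured from a run. A default-constructed
// tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw yields
//   "TUN_256_4x1_128x32x8_4x4x1_2x2x8x8_4x1x1x1x1_2x1x1x128x1_4x1x1x1x1_1x1x1x1x1_"
//   "1x4x1x1x1_8x1x1x32x1_1x1x1x1x1_1x1x1x1x1_1"
// which is used below when building the network_config string passed to Handle::AddKernel.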
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_definition_string_from_types()
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TInWei>()) +
|
||||
" -DCK_PARAM_ACC_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
|
||||
" -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable)
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_BlockSize=" + std::to_string(tunable.BlockSize);
|
||||
|
||||
out += " -DCK_PARAM_GN0=" + std::to_string(tunable.GN0);
|
||||
out += " -DCK_PARAM_GK1=" + std::to_string(tunable.GK1);
|
||||
|
||||
out += " -DCK_PARAM_GM1PerBlockGM11=" + std::to_string(tunable.GM1PerBlockGM11) +
|
||||
" -DCK_PARAM_GN1PerBlockGN11=" + std::to_string(tunable.GN1PerBlockGN11) +
|
||||
" -DCK_PARAM_GK0PerBlock=" + std::to_string(tunable.GK0PerBlock);
|
||||
|
||||
out += " -DCK_PARAM_BM1PerThreadBM11=" + std::to_string(tunable.BM1PerThreadBM11) +
|
||||
" -DCK_PARAM_BN1PerThreadBN11=" + std::to_string(tunable.BN1PerThreadBN11) +
|
||||
" -DCK_PARAM_BK0PerThread=" + std::to_string(tunable.BK0PerThread);
|
||||
|
||||
out += " -DCK_PARAM_BM10BN10ThreadClusterBM100=" +
|
||||
std::to_string(tunable.BM10BN10ThreadClusterBM100) +
|
||||
" -DCK_PARAM_BM10BN10ThreadClusterBN100=" +
|
||||
std::to_string(tunable.BM10BN10ThreadClusterBN100) +
|
||||
" -DCK_PARAM_BM10BN10ThreadClusterBM101=" +
|
||||
std::to_string(tunable.BM10BN10ThreadClusterBM101) +
|
||||
" -DCK_PARAM_BM10BN10ThreadClusterBN101=" +
|
||||
std::to_string(tunable.BM10BN10ThreadClusterBN101);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," +
|
||||
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
|
||||
"," +
|
||||
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," +
|
||||
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
|
||||
"," +
|
||||
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]);
|
||||
|
||||
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
||||
std::to_string(tunable.CThreadTransferDstScalarPerVector);
|
||||
|
||||
return (out);
|
||||
};
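// Hedged illustration, assembled by hand from the default member initializers rather than from
// compiler output: for a default-constructed tunable the returned string begins with
//   " -DCK_PARAM_BlockSize=256 -DCK_PARAM_GN0=4 -DCK_PARAM_GK1=1"
//   " -DCK_PARAM_GM1PerBlockGM11=128 -DCK_PARAM_GN1PerBlockGN11=32 -DCK_PARAM_GK0PerBlock=8"
// and array-valued parameters are emitted as comma-separated lists, e.g.
//   " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=4,1,1,1,1".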
|
||||
|
||||
} // namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
|
||||
olCompile::Handle* handle,
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw;
|
||||
using size_t = std::size_t;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// The following code is only used for computing grid_size, hasMainKBlockLoop and
// hasDoubleTailKBlockLoop
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
tunable.GN0,
|
||||
tunable.GK1);
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
const auto GK = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
const index_t grid_size = (GM1 / tunable.GM1PerBlockGM11) * (GN1 / tunable.GN1PerBlockGN11);
|
||||
const bool hasMainKBlockLoop = ((GK + tunable.GK0PerBlock) / (2 * tunable.GK0PerBlock) > 1);
|
||||
const bool hasDoubleTailKBlockLoop = ((GK / tunable.GK0PerBlock) % 2 == 0);
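// Worked example with hypothetical contraction sizes (not taken from this change): for
// GM1 = 1024, GN1 = 256, GK = 64 and the defaults GM1PerBlockGM11 = 128, GN1PerBlockGN11 = 32,
// GK0PerBlock = 8, grid_size = (1024 / 128) * (256 / 32) = 64 workgroups,
// hasMainKBlockLoop = ((64 + 8) / (2 * 8) > 1) = true, and
// hasDoubleTailKBlockLoop = ((64 / 8) % 2 == 0) = true.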
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
// these are workspace buffers that should be exposed to the user through the corresponding
// workspace API
|
||||
DeviceMem workspace_buf(4096);
|
||||
|
||||
void* a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf = workspace_buf.GetDeviceBuffer();
|
||||
void* b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
|
||||
void* c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
|
||||
void* c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
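// Workspace layout implied by the offsets above: four 1024-byte slots inside the single
// 4096-byte workspace_buf, one per device-side descriptor filled in by the "..._prepare"
// kernel launched below:
//   [   0, 1024)  a_grid_desc_gk0_gm0_gm10_gm11_gk1
//   [1024, 2048)  b_grid_desc_gk0_gn0_gn10_gn11_gk1
//   [2048, 3072)  c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
//   [3072, 4096)  c_grid_block_cluster_blockid_to_gm10_gn10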
|
||||
|
||||
const std::vector<size_t> vld = {static_cast<size_t>(tunable.BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable.BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable.BlockSize), 1, 1};
|
||||
|
||||
std::string program_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp";
|
||||
std::string algo_name = "implicit_gemm_conv_fwd_v6r1_nchw";
|
||||
|
||||
std::string param = " -std=c++17 ";
|
||||
std::string network_config;
|
||||
|
||||
param += get_definition_string_from_types<TInWei, TAcc, TOut>() +
|
||||
" -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) +
|
||||
" -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop) +
|
||||
get_definition_string_from_tunable(tunable);
|
||||
|
||||
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
|
||||
std::to_string(hasDoubleTailKBlockLoop) + "_" +
|
||||
get_network_config_string_from_tunable(tunable);
|
||||
|
||||
std::vector<float> kernel1_times;
|
||||
std::vector<float> kernel2_times;
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
KernelTimer timer1, timer2;
|
||||
std::string kernel_name;
|
||||
|
||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare";
|
||||
auto network_config_1 = network_config + "_1";
|
||||
|
||||
timer1.Start();
|
||||
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
|
||||
conv_strides[I0],
|
||||
conv_strides[I1],
|
||||
conv_dilations[I0],
|
||||
conv_dilations[I1],
|
||||
in_left_pads[I0],
|
||||
in_left_pads[I1],
|
||||
in_right_pads[I0],
|
||||
in_right_pads[I1],
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf);
|
||||
timer1.End();
|
||||
|
||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw";
|
||||
auto network_config_2 = network_config + "_2";
|
||||
|
||||
timer2.Start();
|
||||
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
|
||||
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
|
||||
(const void*)(a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf),
|
||||
(const void*)(b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf),
|
||||
(const void*)(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf),
|
||||
(const void*)(c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf));
|
||||
timer2.End();
|
||||
|
||||
kernel1_times.push_back(timer1.GetElapsedTime());
|
||||
kernel2_times.push_back(timer2.GetElapsedTime());
|
||||
}
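// Launch pattern used in the loop above, as inferred from the call sites (olCompile::Handle's
// internals are not part of this change): AddKernel(algo_name, network_config, program_name,
// kernel_name, vld, vgd, param) builds the named kernel from program_name with the online
// compiler flags in param the first time a given network_config is seen, and returns a
// callable object; invoking that object with the kernel arguments enqueues the kernel with
// local work size vld and global work size vgd.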
|
||||
|
||||
{
|
||||
auto ave_time1 = Driver::get_effective_average(kernel1_times);
|
||||
auto ave_time2 = Driver::get_effective_average(kernel2_times);
|
||||
|
||||
const auto N = in_n_c_hi_wi_lengths[I0];
|
||||
const auto C = in_n_c_hi_wi_lengths[I1];
|
||||
|
||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
||||
|
||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
||||
const auto X = wei_k_c_y_x_lengths[I3];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
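// Unit check with hypothetical sizes (not from this change): N = 128, C = 256, K = 256,
// Ho = Wo = 14, Y = X = 3 gives 2*N*K*Ho*Wo*C*Y*X ~= 2.96e10 FLOP; dividing by 1e9 and by a
// total time of 1.0 ms yields perf ~= 29.6, which the print below reports as TFlop/s
// (GFLOP per millisecond).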
|
||||
|
||||
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
||||
};
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
19
host/host_tensor/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
include_directories(BEFORE
include
)

set(HOST_TENSOR_SOURCE
src/host_tensor.cpp;
src/device.cpp;
)

## the library target
add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE})

target_link_libraries(host_tensor PRIVATE hip::device)
target_link_libraries(host_tensor INTERFACE hip::host)

target_compile_features(host_tensor PUBLIC)
set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON)

install(TARGETS host_tensor LIBRARY DESTINATION lib)
@@ -2,7 +2,8 @@
|
||||
#define DEVICE_HPP
|
||||
|
||||
#include <memory>
|
||||
#include "config.hpp"
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
struct DeviceMem
|
||||
{
|
||||
@@ -30,7 +31,6 @@ struct KernelTimer
|
||||
std::unique_ptr<KernelTimerImpl> impl;
|
||||
};
|
||||
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
using device_stream_t = hipStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
@@ -83,44 +83,4 @@ float launch_and_time_kernel(F kernel,
|
||||
return timer.GetElapsedTime() / nrepeat;
|
||||
}
|
||||
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
using device_stream_t = cudaStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
void launch_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
cudaStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
const void* f = reinterpret_cast<const void*>(kernel);
|
||||
void* p_args[] = {&args...};
|
||||
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
cudaStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
const void* f = reinterpret_cast<const void*>(kernel);
|
||||
void* p_args[] = {&args...};
|
||||
|
||||
timer.Start();
|
||||
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
|
||||
|
||||
timer.End();
|
||||
|
||||
return timer.GetElapsedTime();
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -1,107 +1,59 @@
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
|
||||
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize);
|
||||
#endif
|
||||
}
|
||||
|
||||
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
|
||||
|
||||
void DeviceMem::ToDevice(const void* p)
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipGetErrorString(
|
||||
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, cudaMemcpyHostToDevice);
|
||||
#endif
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p)
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost);
|
||||
#endif
|
||||
}
|
||||
|
||||
DeviceMem::~DeviceMem()
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipGetErrorString(hipFree(mpDeviceBuf));
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaFree(mpDeviceBuf);
|
||||
#endif
|
||||
}
|
||||
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
|
||||
|
||||
struct KernelTimerImpl
|
||||
{
|
||||
KernelTimerImpl()
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipEventCreate(&mStart);
|
||||
hipEventCreate(&mEnd);
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaEventCreate(&mStart);
|
||||
cudaEventCreate(&mEnd);
|
||||
#endif
|
||||
}
|
||||
|
||||
~KernelTimerImpl()
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipEventDestroy(mStart);
|
||||
hipEventDestroy(mEnd);
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaEventDestroy(mStart);
|
||||
cudaEventDestroy(mEnd);
|
||||
#endif
|
||||
}
|
||||
|
||||
void Start()
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipDeviceSynchronize();
|
||||
hipEventRecord(mStart, 0);
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaDeviceSynchronize();
|
||||
cudaEventRecord(mStart, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
void End()
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipEventRecord(mEnd, 0);
|
||||
hipEventSynchronize(mEnd);
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaEventRecord(mEnd, 0);
|
||||
cudaEventSynchronize(mEnd);
|
||||
#endif
|
||||
}
|
||||
|
||||
float GetElapsedTime() const
|
||||
{
|
||||
float time;
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipEventElapsedTime(&time, mStart, mEnd);
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaEventElapsedTime(&time, mStart, mEnd);
|
||||
#endif
|
||||
return time;
|
||||
}
|
||||
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
hipEvent_t mStart, mEnd;
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
cudaEvent_t mStart, mEnd;
|
||||
#endif
|
||||
};
|
||||
|
||||
KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {}
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
set(CMAKE_CXX_COMPILER /opt/rocm/llvm/bin/clang++)
|
||||
|
||||
## for online-compiling of HIP kernels
|
||||
@@ -17,6 +16,7 @@ if(OLC_HIP_COMPILER MATCHES ".*clang\\+\\+$")
|
||||
${CMAKE_INSTALL_PREFIX}/llvm
|
||||
)
|
||||
endif()
|
||||
|
||||
if(OLC_OFFLOADBUNDLER_BIN)
|
||||
message(STATUS "clang-offload-bundler found: ${OLC_OFFLOADBUNDLER_BIN}")
|
||||
set(OLC_OFFLOADBUNDLER_BIN "${OLC_OFFLOADBUNDLER_BIN}")
|
||||
@@ -67,92 +67,58 @@ else()
|
||||
set(OLC_DEBUG 0)
|
||||
endif()
|
||||
|
||||
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/config.h.in" "${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/config.h")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/host/online_compilation/include/config.h.in" "${PROJECT_BINARY_DIR}/host/online_compilation/include/config.h")
|
||||
|
||||
include_directories(BEFORE
|
||||
${PROJECT_BINARY_DIR}/host/online_compilation/include
|
||||
)
|
||||
|
||||
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
|
||||
|
||||
## HIP_COMPILER_FLAGS will be used for online compilation of the HIP kernels
|
||||
add_definitions("-DHIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}")
|
||||
|
||||
file(GLOB COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm/*.hpp")
|
||||
file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description/*.hpp")
|
||||
file(GLOB COMPOSABLE_KERNEL_INCLUDE_3 "${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation/*.hpp")
|
||||
file(GLOB COMPOSABLE_KERNEL_INCLUDE_4 "${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/*.hpp")
|
||||
file(GLOB COMPOSABLE_KERNEL_INCLUDE_5 "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/*.hpp")
|
||||
file(GLOB COMPOSABLE_KERNEL_INCLUDE_6 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp")
|
||||
file(GLOB_RECURSE COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/*/*.hpp")
|
||||
file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp")
|
||||
set(MCONV_KERNEL_INCLUDES
|
||||
${COMPOSABLE_KERNEL_INCLUDE_1}
|
||||
${COMPOSABLE_KERNEL_INCLUDE_2}
|
||||
${COMPOSABLE_KERNEL_INCLUDE_3}
|
||||
${COMPOSABLE_KERNEL_INCLUDE_4}
|
||||
${COMPOSABLE_KERNEL_INCLUDE_5}
|
||||
${COMPOSABLE_KERNEL_INCLUDE_6}
|
||||
)
|
||||
|
||||
set(MCONV_KERNELS
|
||||
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp
|
||||
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp
|
||||
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
|
||||
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
|
||||
)
|
||||
file(GLOB_RECURSE MCONV_KERNELS "${PROJECT_SOURCE_DIR}/composable_kernel/src/kernel_wrapper/*.cpp")
|
||||
|
||||
add_kernels("olCompiling/" "${MCONV_KERNELS}")
|
||||
add_kernel_includes("olCompiling/" "${MCONV_KERNEL_INCLUDES}")
|
||||
add_kernels(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNELS}")
|
||||
add_kernel_includes(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNEL_INCLUDES}")
|
||||
|
||||
set(MCONV_SOURCES
|
||||
src/host_tensor.cpp;
|
||||
src/device.cpp;
|
||||
set(ONLINE_COMPILATION_SOURCE
|
||||
${PROJECT_BINARY_DIR}/kernel.cpp
|
||||
${PROJECT_BINARY_DIR}/kernel_includes.cpp
|
||||
)
|
||||
|
||||
set(OLC_HIP_UTILITY_HEADERS
|
||||
olCompiling/include/config.h
|
||||
olCompiling/include/logger.hpp
|
||||
olCompiling/include/stringutils.hpp
|
||||
olCompiling/include/tmp_dir.hpp
|
||||
olCompiling/include/write_file.hpp
|
||||
olCompiling/include/env.hpp
|
||||
olCompiling/include/manage_ptr.hpp
|
||||
olCompiling/include/md5.hpp
|
||||
olCompiling/include/simple_hash.hpp
|
||||
olCompiling/include/exec_utils.hpp
|
||||
olCompiling/include/hipCheck.hpp
|
||||
olCompiling/include/target_properties.hpp
|
||||
olCompiling/include/handle.hpp
|
||||
olCompiling/include/op_kernel_args.hpp
|
||||
olCompiling/include/kernel.hpp
|
||||
olCompiling/include/kernel_build_params.hpp
|
||||
olCompiling/include/hip_build_utils.hpp
|
||||
olCompiling/include/hipoc_program.hpp
|
||||
olCompiling/include/hipoc_program_impl.hpp
|
||||
olCompiling/include/hipoc_kernel.hpp
|
||||
olCompiling/include/kernel_cache.hpp
|
||||
olCompiling/include/binary_cache.hpp
|
||||
)
|
||||
include_directories(BEFORE
|
||||
${PROJECT_BINARY_DIR}/host/online_compilation/include
|
||||
include
|
||||
)
|
||||
|
||||
set(OLC_HIP_UTILITY_CPPS
olCompiling/hip_utility/logger.cpp
olCompiling/hip_utility/tmp_dir.cpp
olCompiling/hip_utility/md5.cpp
olCompiling/hip_utility/exec_utils.cpp
olCompiling/hip_utility/target_properties.cpp
olCompiling/hip_utility/handlehip.cpp
olCompiling/hip_utility/kernel_build_params.cpp
olCompiling/hip_utility/hip_build_utils.cpp
olCompiling/hip_utility/hipoc_program.cpp
olCompiling/hip_utility/hipoc_kernel.cpp
olCompiling/hip_utility/kernel_cache.cpp
olCompiling/hip_utility/binary_cache.cpp
hip_utility/logger.cpp
hip_utility/tmp_dir.cpp
hip_utility/md5.cpp
hip_utility/exec_utils.cpp
hip_utility/target_properties.cpp
hip_utility/handlehip.cpp
hip_utility/kernel_build_params.cpp
hip_utility/hip_build_utils.cpp
hip_utility/hipoc_program.cpp
hip_utility/hipoc_kernel.cpp
hip_utility/kernel_cache.cpp
hip_utility/binary_cache.cpp
)

list(APPEND OLC_SOURCES ${OLC_HIP_UTILITY_CPPS} ${OLC_HIP_UTILITY_HEADERS})

list(INSERT MCONV_SOURCES 0
${PROJECT_BINARY_DIR}/kernel.cpp
${PROJECT_BINARY_DIR}/kernel_includes.cpp
)

## addkernels provides the tool that inlines all kernels into a single header
add_subdirectory(olCompiling/addkernels)
add_subdirectory(addkernels)

function(inline_kernels_src KERNELS KERNEL_INCLUDES)
set(KERNEL_SRC_HPP_FILENAME batch_all.cpp.hpp)
@@ -166,7 +132,7 @@ function(inline_kernels_src KERNELS KERNEL_INCLUDES)
COMMAND $<TARGET_FILE:addkernels> -target ${KERNEL_SRC_HPP_PATH} -extern -source ${KERNELS}
COMMENT "Inlining All kernels"
)
configure_file(olCompiling/kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH})
configure_file(kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH})
list(APPEND OLC_SOURCES ${KERNEL_SRC_CPP_PATH} ${KERNEL_SRC_HPP_PATH})

set(OLC_SOURCES ${OLC_SOURCES} PARENT_SCOPE)
@@ -174,7 +140,7 @@ endfunction()

inline_kernels_src("${MCONV_KERNELS}" "${MCONV_KERNEL_INCLUDES}")

list(APPEND MCONV_SOURCES ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h)
list(APPEND ONLINE_COMPILATION_SOURCE ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h)

add_custom_command(
OUTPUT ${PROJECT_BINARY_DIR}/olc_kernel_includes.h
@@ -185,19 +151,17 @@ add_custom_command(
)

## the library target
add_library(modConv SHARED ${MCONV_SOURCES})
add_library(online_compilation SHARED ${ONLINE_COMPILATION_SOURCE})

target_include_directories(modConv PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/)
target_include_directories(modConv PRIVATE ${PROJECT_BINARY_DIR})
target_include_directories(modConv PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/)
target_include_directories(online_compilation PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/online_compilation/include/)
target_include_directories(online_compilation PRIVATE ${PROJECT_BINARY_DIR})
target_include_directories(online_compilation PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/)

target_link_libraries(modConv PRIVATE hip::device)
target_link_libraries(modConv INTERFACE hip::host)
target_link_libraries(modConv PRIVATE Boost::filesystem)
target_link_libraries(online_compilation PRIVATE hip::device)
target_link_libraries(online_compilation INTERFACE hip::host)
target_link_libraries(online_compilation PRIVATE Boost::filesystem)

target_compile_options(modConv PRIVATE -mfma)
target_compile_features(online_compilation PUBLIC)
set_target_properties(online_compilation PROPERTIES POSITION_INDEPENDENT_CODE ON)

target_compile_features(modConv PUBLIC)
set_target_properties(modConv PROPERTIES POSITION_INDEPENDENT_CODE ON)

install(TARGETS modConv LIBRARY DESTINATION lib)
install(TARGETS online_compilation LIBRARY DESTINATION lib)
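
## Minimal usage sketch, assuming a consumer target built inside this source tree; the
## executable name and source file below are hypothetical and not part of this CMakeLists.
add_executable(example_olc_driver example_olc_driver.cpp)
target_link_libraries(example_olc_driver PRIVATE online_compilation)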