Mirror of https://github.com/ROCm/composable_kernel.git
Restructure gridwise and blockwise GEMM, add tensor contraction and FWD-v4r5 (#36)
* experimenting with magic number division
* overhauling fwd-v4r4 to clearly reflect transformation graph
* added fwd-v4r5
* bug fix for make_dynamic_naive_tensor_descriptor_aligned_v2
* bug fix and added sanity-check in transform_dynamic_tensor_descriptor
* added conv_driver_v2
[ROCm/composable_kernel commit: 30072aec37]
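The first bullet above refers to the classic multiply-and-shift replacement for integer division by a run-time-invariant divisor, which avoids slow hardware division in GPU index arithmetic. A minimal host-side sketch of the idea only, not the implementation this commit experiments with (MagicDivider and its members are hypothetical names):

#include <cassert>
#include <cstdint>

// Precompute (multiplier, shift) so that n / d == (n * multiplier) >> shift
// for every 32-bit n; one multiply and one shift then replace each division.
struct MagicDivider
{
    std::uint64_t multiplier;
    std::uint32_t shift;

    explicit MagicDivider(std::uint32_t d)
    {
        assert(d != 0);
        std::uint32_t l = 0;
        while((std::uint64_t{1} << l) < d) // l = ceil(log2(d))
        {
            ++l;
        }
        shift = 32 + l;
        // round-up multiplier: ceil(2^shift / d); may need up to 33 bits
        multiplier = static_cast<std::uint64_t>(((__uint128_t{1} << shift) + d - 1) / d);
    }

    std::uint32_t divide(std::uint32_t n) const
    {
        return static_cast<std::uint32_t>((__uint128_t{n} * multiplier) >> shift);
    }
};

// Usage: the divisor (e.g., a tensor extent) is fixed once, then reused for
// many coordinate calculations:
//   MagicDivider div7{7};
//   assert(div7.divide(100) == 14);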
@@ -0,0 +1,292 @@
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r1.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGKGM0GM1GridDesc,
          typename BGKGN0GN1GridDesc,
          typename CGM0GM1GN0GN1GridDesc,
          index_t GM1PerBlockGM11,
          index_t GN1PerBlockGN11,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
          typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_GM11,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
          typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_GN11,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r1(const FloatAB* p_a_grid,
                                const FloatAB* p_b_grid,
                                FloatC* p_c_grid,
                                const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
                                const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
                                const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
                                AGridIteratorHacks,
                                BGridIteratorHacks,
                                CGridIteratorHacks,
                                AGridMoveSliceWindowIteratorHacks,
                                BGridMoveSliceWindowIteratorHacks,
                                index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
        BlockSize,
        FloatAB,
        FloatAcc,
        FloatC,
        CGlobalMemoryDataOperation,
        AGKGM0GM1GridDesc,
        BGKGN0GN1GridDesc,
        CGM0GM1GN0GN1GridDesc,
        GM1PerBlockGM11,
        GN1PerBlockGN11,
        KPerBlock,
        M1PerThread,
        N1PerThread,
        KPerThread,
        M1N1ThreadClusterM10,
        M1N1ThreadClusterN10,
        M1N1ThreadClusterM11,
        M1N1ThreadClusterN11,
        ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
        ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_GM11,
        AThreadTransferSrcResetCoordinateAfterRun,
        BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
        BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_GN11,
        BThreadTransferSrcResetCoordinateAfterRun,
        CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        AGridIteratorHacks,
        BGridIteratorHacks,
        CGridIteratorHacks,
        AGridMoveSliceWindowIteratorHacks,
        BGridMoveSliceWindowIteratorHacks>;

    const auto K = a_gk_gm0_gm1_grid_desc.GetLength(I0);

    if(!GridwiseContraction::CheckValidity(
           a_gk_gm0_gm1_grid_desc, b_gk_gn0_gn1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
    {
        throw std::runtime_error(
            "wrong! GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1 has invalid setting");
    }

    const auto a_gk_gm0_gm10_gm11_grid_desc =
        GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc);
    const auto b_gk_gn0_gn10_gn11_grid_desc =
        GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc);

    using AGKGM0GM10GM11GridDesc = decltype(a_gk_gm0_gm10_gm11_grid_desc);
    using BGKGN0GN10GN11GridDesc = decltype(b_gk_gn0_gn10_gn11_grid_desc);

    // c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
    const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
        GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);

    using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);

    // c_blockid_to_gm10_gn10_block_cluster_adaptor
    const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
        GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);

    using CBlockIdToGM10GN10BlockClusterAdaptor =
        decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);

    const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);

    const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(K);

    const bool has_double_tail_k_block_loop =
        GridwiseContraction::CalculateHasDoubleTailKBlockLoop(K);

    {
        std::cout << "a_gk_gm0_gm10_gm11_grid_desc{" << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0)
                  << ", " << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I1) << ", "
                  << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I2) << ", "
                  << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_gk_gn0_gn10_gn11_grid_desc{" << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I0)
                  << ", " << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I1) << ", "
                  << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I2) << ", "
                  << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
    }

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            true,
            true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            true,
            false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            false,
            true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else
    {
        const auto kernel = kernel_dynamic_contraction_v1r1<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGKGM0GM10GM11GridDesc>,
            remove_reference_t<BGKGN0GN10GN11GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            false,
            false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_gk_gm0_gm10_gm11_grid_desc,
                                          b_gk_gn0_gn10_gn11_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }

    return ave_time;
}

} // namespace ck
#endif
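Every specialization above is timed through launch_and_time_kernel, which averages nrepeat launches. One plausible shape for such a helper, assuming HIP's event API; a sketch for orientation, not the repository's actual implementation:

#include <cstddef>
#include <hip/hip_runtime.h>

// Launch `kernel` nrepeat times on `stream` and return the average time in ms.
// The two scalar arguments after the block dimension mirror the call sites
// above: dynamic LDS bytes and the stream (both passed as 0 there).
template <typename Kernel, typename... Args>
float launch_and_time_kernel(Kernel kernel,
                             int nrepeat,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_bytes,
                             hipStream_t stream,
                             Args... args)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, stream);
    for(int i = 0; i < nrepeat; ++i)
    {
        hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_bytes, stream, args...);
    }
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop);

    float total_ms = 0;
    hipEventElapsedTime(&total_ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);

    return total_ms / nrepeat;
}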
@@ -1,388 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThread,
          index_t NPerThread,
          index_t KPerThread,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N,
          typename BBlockTransferThreadClusterLengths_K_N,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalIteratorHacks,
          typename BGlobalIteratorHacks,
          typename CGlobalIteratorHacks,
          typename AGlobalMoveSliceWindowIteratorHacks,
          typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                             const FloatAB* p_b_global,
                                             FloatC* p_c_global,
                                             const AGlobalDesc& a_k_m_global_desc,
                                             const BGlobalDesc& b_k_n_global_desc,
                                             const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                             const CBlockClusterDesc& c_block_cluster_desc,
                                             AGlobalIteratorHacks,
                                             BGlobalIteratorHacks,
                                             CGlobalIteratorHacks,
                                             AGlobalMoveSliceWindowIteratorHacks,
                                             BGlobalMoveSliceWindowIteratorHacks,
                                             index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto M = a_k_m_global_desc.GetLength(I1);
    const auto N = b_k_n_global_desc.GetLength(I1);
    const auto K = a_k_m_global_desc.GetLength(I0);

    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
    constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};

    if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    // GEMM
    using gridwise_gemm =
        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize,
                                              FloatAB,
                                              FloatAcc,
                                              FloatC,
                                              CGlobalMemoryDataOperation,
                                              AGlobalDesc,
                                              BGlobalDesc,
                                              CGlobalDesc,
                                              CBlockClusterDesc,
                                              MPerBlock,
                                              NPerBlock,
                                              KPerBlock,
                                              MPerThread,
                                              NPerThread,
                                              KPerThread,
                                              MLevel0Cluster,
                                              NLevel0Cluster,
                                              MLevel1Cluster,
                                              NLevel1Cluster,
                                              ABlockTransferThreadSliceLengths_K_M,
                                              ABlockTransferThreadClusterLengths_K_M,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              ABlockTransferSrcAccessOrder,
                                              ABlockTransferSrcVectorDim,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_M,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              BBlockTransferThreadSliceLengths_K_N,
                                              BBlockTransferThreadClusterLengths_K_N,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              BBlockTransferSrcAccessOrder,
                                              BBlockTransferSrcVectorDim,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_N,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              CThreadTransferSrcDstAccessOrder,
                                              CThreadTransferSrcDstVectorDim,
                                              CThreadTransferDstScalarPerVector,
                                              AGlobalIteratorHacks,
                                              BGlobalIteratorHacks,
                                              CGlobalIteratorHacks,
                                              AGlobalMoveSliceWindowIteratorHacks,
                                              BGlobalMoveSliceWindowIteratorHacks>;

    const auto GridSize = (M / MPerBlock) * (N / NPerBlock);

    const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));

    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,
                                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   false,
                                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
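The four-way dispatch in both the old and the new drivers hinges on two predicates of the double-buffered K loop: the main loop consumes 2 * KPerBlock per iteration, so has_main_k_block_loop asks whether more than one double-block of K exists, and has_double_tail_k_block_loop asks whether the leftover is two KPerBlock slices rather than one. A small table of cases, assuming KPerBlock divides K:

#include <cstdio>

int main()
{
    const int KPerBlock = 8;

    for(int K : {8, 16, 24, 32, 40})
    {
        // same expressions as in the drivers above
        const bool has_main_k_block_loop        = (K + KPerBlock) / (2 * KPerBlock) > 1;
        const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

        std::printf("K=%2d  main=%d  double_tail=%d\n",
                    K,
                    has_main_k_block_loop,
                    has_double_tail_k_block_loop);
    }
    // prints: K= 8  main=0  double_tail=0   (single tail block only)
    //         K=16  main=0  double_tail=1   (double tail, no main loop)
    //         K=24  main=1  double_tail=0   (main loop + single tail)
    //         K=32  main=1  double_tail=1   (main loop + double tail)
    //         K=40  main=1  double_tail=0
    return 0;
}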
composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp (new file, 387 lines)
@@ -0,0 +1,387 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R1
#define CK_DRIVER_DYNAMIC_GEMM_V1R1

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r1.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          typename CBlockClusterDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K_M,
          typename ABlockTransferThreadClusterLengths_K_M,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N,
          typename BBlockTransferThreadClusterLengths_K_N,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGlobalIteratorHacks,
          typename BGlobalIteratorHacks,
          typename CGlobalIteratorHacks,
          typename AGlobalMoveSliceWindowIteratorHacks,
          typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
                                               const FloatAB* p_b_global,
                                               FloatC* p_c_global,
                                               const AGlobalDesc& a_k_m_global_desc,
                                               const BGlobalDesc& b_k_n_global_desc,
                                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
                                               const CBlockClusterDesc& c_block_cluster_desc,
                                               AGlobalIteratorHacks,
                                               BGlobalIteratorHacks,
                                               CGlobalIteratorHacks,
                                               AGlobalMoveSliceWindowIteratorHacks,
                                               BGlobalMoveSliceWindowIteratorHacks,
                                               index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto M = a_k_m_global_desc.GetLength(I1);
    const auto N = b_k_n_global_desc.GetLength(I1);
    const auto K = a_k_m_global_desc.GetLength(I0);

    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
    constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};

    if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
    {
        throw std::runtime_error("wrong! GEMM size not divisible");
    }

    // GEMM
    using gridwise_gemm =
        GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                CGlobalMemoryDataOperation,
                                                AGlobalDesc,
                                                BGlobalDesc,
                                                CGlobalDesc,
                                                CBlockClusterDesc,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                M1PerThread,
                                                N1PerThread,
                                                KPerThread,
                                                M1N1ThreadClusterM10,
                                                M1N1ThreadClusterN10,
                                                M1N1ThreadClusterM11,
                                                M1N1ThreadClusterN11,
                                                ABlockTransferThreadSliceLengths_K_M,
                                                ABlockTransferThreadClusterLengths_K_M,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_M,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockTransferThreadSliceLengths_K_N,
                                                BBlockTransferThreadClusterLengths_K_N,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_N,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector,
                                                AGlobalIteratorHacks,
                                                BGlobalIteratorHacks,
                                                CGlobalIteratorHacks,
                                                AGlobalMoveSliceWindowIteratorHacks,
                                                BGlobalMoveSliceWindowIteratorHacks>;

    const auto GridSize = (M / MPerBlock) * (N / NPerBlock);

    const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true,
                                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true,
                                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false,
                                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false,
                                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));

    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true,
                                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     true,
                                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false,
                                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
                                                     FloatAB,
                                                     FloatAB,
                                                     FloatC,
                                                     remove_reference_t<AGlobalDesc>,
                                                     remove_reference_t<BGlobalDesc>,
                                                     remove_reference_t<CGlobalDesc>,
                                                     remove_reference_t<CBlockClusterDesc>,
                                                     false,
                                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(GridSize),
            dim3(BlockSize),
            0,
            0,
            p_a_global,
            p_b_global,
            p_c_global,
            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp (new file, 285 lines)
@@ -0,0 +1,285 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R2
#define CK_DRIVER_DYNAMIC_GEMM_V1R2

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r2.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AKMGridDesc,
          typename BKNGridDesc,
          typename CMNGridDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K_M0_M1,
          typename ABlockTransferThreadClusterLengths_K_M0_M1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N0_N1,
          typename BBlockTransferThreadClusterLengths_K_N0_N1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
                                        const AKMGridDesc& a_k_m_grid_desc,
                                        const BKNGridDesc& b_k_n_grid_desc,
                                        const CMNGridDesc& c_m_n_grid_desc,
                                        AGridIteratorHacks,
                                        BGridIteratorHacks,
                                        CGridIteratorHacks,
                                        AGridMoveSliceWindowIteratorHacks,
                                        BGridMoveSliceWindowIteratorHacks,
                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseGemm =
        GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
                                          FloatAB,
                                          FloatAcc,
                                          FloatC,
                                          CGlobalMemoryDataOperation,
                                          AKMGridDesc,
                                          BKNGridDesc,
                                          CMNGridDesc,
                                          MPerBlock,
                                          NPerBlock,
                                          KPerBlock,
                                          M1PerThread,
                                          N1PerThread,
                                          KPerThread,
                                          M1N1ThreadClusterM10,
                                          M1N1ThreadClusterN10,
                                          M1N1ThreadClusterM11,
                                          M1N1ThreadClusterN11,
                                          ABlockTransferThreadSliceLengths_K_M0_M1,
                                          ABlockTransferThreadClusterLengths_K_M0_M1,
                                          ABlockTransferThreadClusterArrangeOrder,
                                          ABlockTransferSrcAccessOrder,
                                          ABlockTransferSrcVectorDim,
                                          ABlockTransferSrcScalarPerVector,
                                          ABlockTransferDstScalarPerVector_M1,
                                          AThreadTransferSrcResetCoordinateAfterRun,
                                          BBlockTransferThreadSliceLengths_K_N0_N1,
                                          BBlockTransferThreadClusterLengths_K_N0_N1,
                                          BBlockTransferThreadClusterArrangeOrder,
                                          BBlockTransferSrcAccessOrder,
                                          BBlockTransferSrcVectorDim,
                                          BBlockTransferSrcScalarPerVector,
                                          BBlockTransferDstScalarPerVector_N1,
                                          BThreadTransferSrcResetCoordinateAfterRun,
                                          CThreadTransferSrcDstAccessOrder,
                                          CThreadTransferSrcDstVectorDim,
                                          CThreadTransferDstScalarPerVector,
                                          AGridIteratorHacks,
                                          BGridIteratorHacks,
                                          CGridIteratorHacks,
                                          AGridMoveSliceWindowIteratorHacks,
                                          BGridMoveSliceWindowIteratorHacks>;

    const auto M = a_k_m_grid_desc.GetLength(I1);
    const auto N = b_k_n_grid_desc.GetLength(I1);
    const auto K = a_k_m_grid_desc.GetLength(I0);

    if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting");
    }

    const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);

    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // c_blockid_to_m0_n0_block_cluster_adaptor
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

    const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

    const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);

    {
        std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
                  << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
                  << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
    }

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
}

} // namespace ck
#endif
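All of these drivers launch one workgroup per C tile, and the block-cluster adaptor inverts the flat block id back into a (block-row, block-column) pair of C. A standalone sketch of that arithmetic, assuming a row-major ordering of tiles (an illustration, not the adaptor implementation):

#include <cassert>

int main()
{
    const int M = 256, N = 512;
    const int MPerBlock = 128, NPerBlock = 128;

    const int MBlocks   = M / MPerBlock;     // block rows of C
    const int NBlocks   = N / NPerBlock;     // block columns of C
    const int grid_size = MBlocks * NBlocks; // one workgroup per C tile

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int m0 = block_id / NBlocks; // which block row
        const int n0 = block_id % NBlocks; // which block column
        assert(m0 * NBlocks + n0 == block_id);
    }
    return 0;
}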
@@ -10,11 +10,7 @@ namespace ck {
 // GemmM = K
 // GemmN = N * Ho * Wo
 // GemmK = C * Y * X
-template <index_t GemmMPerBlock,
-          index_t GemmNPerBlock,
-          index_t GemmM1,
-          index_t GemmN1,
-          typename... Wei,
+template <typename... Wei,
           typename... In,
           typename... Out,
           typename ConvStrides,
@@ -100,74 +96,11 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
         make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
-
-    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
-
-    const auto GemmM0 = GemmM / Number<GemmM1>{};
-    const auto GemmN0 = GemmN / Number<GemmN1>{};
-
-    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
-        out_gemmm_gemmn_global_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
-                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
-    // out_gemm_block_cluster_desc
-    const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
-        make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
-
-    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
-    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-
-    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
-
-    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
-    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
-
-    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
-
-    // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
-    // tensor hack for NKHW format
-    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{}));
-
-    return make_tuple(wei_gemmk_gemmm_global_desc,
-                      in_gemmk_gemmn_global_desc,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                      out_gemm_block_cluster_desc,
-                      wei_gemmk_gemmm_global_iterator_hacks,
-                      in_gemmk_gemmn_global_iterator_hacks,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
-                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
-                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
+    return make_tuple(
+        wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
 }
 
 // GemmM = K
 // GemmN = N * Ho * Wo
 // GemmK = C * Y * X
-template <index_t GemmMPerBlock,
-          index_t GemmNPerBlock,
-          index_t GemmM1,
-          index_t GemmN1,
-          typename... Wei,
+template <typename... Wei,
           typename... In,
           typename... Out,
           typename ConvStrides,
@@ -247,72 +180,11 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
         make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
-
-    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
-
-    const auto GemmM0 = GemmM / Number<GemmM1>{};
-    const auto GemmN0 = GemmN / Number<GemmN1>{};
-
-    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
-        out_gemmm_gemmn_global_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
-                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
-    // out_gemm_block_cluster_desc
-    const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
-        make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
-
-    // hack to control index calculation when iterating over a_k_m_global tensor
-    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-
-    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
-
-    // hack to control index calculation when iterating over b_k_n_global tensor
-    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{}));
-
-    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 1, 2>{};
-
-    // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-    // hack for NKHW format
-    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{}));
-
-    return make_tuple(wei_gemmk_gemmm_global_desc,
-                      in_gemmk_gemmn_global_desc,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                      out_gemm_block_cluster_desc,
-                      wei_gemmk_gemmm_global_iterator_hacks,
-                      in_gemmk_gemmn_global_iterator_hacks,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
-                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
-                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
+    return make_tuple(
+        wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
 }
 
 // GemmM = K
 // GemmN = N * Ho * Wo
 // GemmK = C * Y * X
-template <index_t GemmMPerBlock,
-          index_t GemmNPerBlock,
-          index_t GemmM1,
-          index_t GemmN1,
-          typename... Wei,
+template <typename... Wei,
           typename... In,
           typename... Out,
           typename ConvStrides,
@@ -383,60 +255,8 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
         make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-    const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-    const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-    const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
-
-    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
-
-    const auto GemmM0 = GemmM / Number<GemmM1>{};
-    const auto GemmN0 = GemmN / Number<GemmN1>{};
-
-    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
-        out_gemmm_gemmn_global_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
-                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
-    // out_gemm_block_cluster_desc
-    const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
-        make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
-
-    // hack to control index calculation when iterating over a_k_m_global tensor
-    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-
-    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
-
-    // hack to control index calculation when iterating over b_k_n_global tensor
-    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}),
-                   make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{}));
-
-    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 1, 2>{};
-
-    // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{}));
-
-    return make_tuple(wei_gemmk_gemmm_global_desc,
-                      in_gemmk_gemmn_global_desc,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                      out_gemm_block_cluster_desc,
-                      wei_gemmk_gemmm_global_iterator_hacks,
-                      in_gemmk_gemmn_global_iterator_hacks,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
-                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
-                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
+    return make_tuple(
+        wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
 }
 
 } // namespace ck
 
@@ -10,11 +10,7 @@ namespace ck {
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmM1,
|
||||
index_t GemmN1,
|
||||
typename... Wei,
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
@@ -100,74 +96,11 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
|
||||
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
|
||||
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
|
||||
|
||||
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
|
||||
|
||||
const auto GemmM0 = GemmM / Number<GemmM1>{};
|
||||
const auto GemmN0 = GemmN / Number<GemmN1>{};
|
||||
|
||||
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
|
||||
out_gemmm_gemmn_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
|
||||
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// out_gemm_block_cluster_desc
|
||||
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
|
||||
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_k_n_global tensor
|
||||
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk_gemmm_global_desc,
|
||||
in_gemmk_gemmn_global_desc,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
|
||||
out_gemm_block_cluster_desc,
|
||||
wei_gemmk_gemmm_global_iterator_hacks,
|
||||
in_gemmk_gemmn_global_iterator_hacks,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
|
||||
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
|
||||
in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmM1,
|
||||
index_t GemmN1,
|
||||
typename... Wei,
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
@@ -238,61 +171,8 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
|
||||
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
|
||||
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
|
||||
|
||||
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
|
||||
|
||||
const auto GemmM0 = GemmM / Number<GemmM1>{};
|
||||
const auto GemmN0 = GemmN / Number<GemmN1>{};
|
||||
|
||||
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
|
||||
out_gemmm_gemmn_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
|
||||
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// out_gemm_block_cluster_desc
|
||||
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
|
||||
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
|
||||
|
||||
// hack to control index calculation when iterating over wei_gemmk_gemmm_global_iterator_hacks
|
||||
// tensor
|
||||
constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_k_n_global tensor
|
||||
constexpr auto in_gemmk_gemmn_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk_gemmm_global_desc,
|
||||
in_gemmk_gemmn_global_desc,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
|
||||
out_gemm_block_cluster_desc,
|
||||
wei_gemmk_gemmm_global_iterator_hacks,
|
||||
in_gemmk_gemmn_global_iterator_hacks,
|
||||
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
|
||||
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
|
||||
in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <index_t N0_,
|
||||
typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad(
|
||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gk_gm0_gm1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1, 2>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto N0 = Number<N0_>{};
|
||||
const auto N1 = N / N0;
|
||||
|
||||
const auto in_n0_n1_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
|
||||
|
||||
const auto in_gk_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n0_n1_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N0),
|
||||
make_merge_transform(make_tuple(N1, Ho, Wo))),
|
||||
make_tuple(Sequence<2, 3, 5>{}, Sequence<0>{}, Sequence<1, 4, 6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_n_k_howo_grid_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
|
||||
|
||||
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
out_n_k_howo_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
|
||||
make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_pass_through_transform(Ho * Wo)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
out_n0_n1_1_k_howo_grid_desc,
|
||||
make_tuple(make_pass_through_transform(I1),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(Number<N0>{}),
|
||||
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
|
||||
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
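    // To summarize the contraction view built above: the weight side supplies
    // GM0 = 1 and GM1 = K, the input side supplies GN0 = N0 and
    // GN1 = (N / N0) * Ho * Wo, and both share GK = C * Y * X. With
    // hypothetical values N = 128 and N0 = 4, this gives GN0 = 4 and
    // GN1 = 32 * Ho * Wo.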

    return make_tuple(
        wei_gk_gm0_gm1_grid_desc, in_gk_gn0_gn1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
}

} // namespace ck
#endif
@@ -1164,6 +1164,165 @@ struct DynamicMerge_v2_magic_division
    }
};

// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
// of both multi-index and delta of multi-index
// Caution:
// 1. The magic number division implementation being used would produce the correct result if the
// dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
// 2. The magic number division for an int32_t dividend has not been implemented; an int32_t
// dividend is bit-wise interpreted as uint32_t and the magic number division implementation for
// uint32_t is then used.
// 3. For the Merge primitive, the upper-index is the dividend.
// 4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
// 5. When the upper-index is int32_t (when index_t is int32_t), its value needs to be
// non-negative.
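// Illustrative sketch only, not part of this commit (magic_div_sketch is a
// hypothetical helper; the real entry points are
// MagicDivision::CalculateMagicMultiplier/CalculateMagicShift and
// MagicDivision::DoMagicDivision): one common magic-number scheme turns
// division by a fixed d > 1 into a 64-bit multiply plus shifts. The 31-bit
// dividend restriction above is what keeps the 64-bit product from overflowing.
//
//     inline uint32_t magic_div_sketch(uint32_t n, uint32_t d) // n < 2^31, d > 1
//     {
//         const uint32_t s = 32 - __builtin_clz(d - 1);               // ceil(log2(d))
//         const uint64_t m = ((uint64_t(1) << (32 + s)) + d - 1) / d; // ceil(2^(32+s) / d)
//         return uint32_t((uint64_t(n) * m) >> 32) >> s;              // == n / d
//     }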
template <typename LowLengths>
struct DynamicMerge_v2r2_magic_division
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan = decltype(
        container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));

    using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
        Number<NDimLow>{}));

    using LowLengthsScanMagicDivisorShift = decltype(
        generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
                       Number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_;
    LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr DynamicMerge_v2r2_magic_division() = default;

    __host__ __device__ constexpr DynamicMerge_v2r2_magic_division(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
          low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
              Number<NDimLow>{})},
          low_lengths_scan_magic_divisor_shift_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
              Number<NDimLow>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
            idx_low(i) =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
                                               this->low_lengths_scan_magic_divisor_shift_[i]);

            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }
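    // Worked example of the lowering above: merging low lengths (4, 3) gives
    // low_lengths_scan_ = (3, 1) and an upper length of 12; an upper index of
    // 7 lowers to idx_low = (7 / 3, 7 - 2 * 3) = (2, 1), with the division
    // carried out by DoMagicDivision instead of an integer divide.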

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up_new[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
            index_t idx_low_old = idx_low[i];

            idx_low(i) =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
                                               this->low_lengths_scan_magic_divisor_shift_[i]);

            idx_diff_low(i) = idx_low[i] - idx_low_old;

            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_diff_low(Number<NDimLow - 1>{}) = tmp - idx_low[Number<NDimLow - 1>{}];

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScanMagicDivisorMultipiler>::value &&
               is_known_at_compile_time<LowLengthsScanMagicDivisorShift>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("DynamicMerge_v2r2_magic_division, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan ");
        print_multi_index(low_lengths_scan_);
        printf("low_lengths_scan_magic_divisor_multiplier_ ");
        print_multi_index(low_lengths_scan_magic_divisor_multiplier_);
        printf("low_lengths_scan_magic_divisor_shift_ ");
        print_multi_index(low_lengths_scan_magic_divisor_shift_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};

template <typename UpLengths, bool Use24BitIntegerCalculation>
struct DynamicUnMerge
{

@@ -56,8 +56,19 @@ __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_le
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
    return DynamicMerge_v1_carry_check<LowLengths>{low_lengths};
#else
#if 1
    return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
#else
    return DynamicMerge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
#endif
}

template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
    return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
}
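// Usage sketch: transform_forward_convolution_into_contraction_v4r5 above
// merges (N1, Ho * Wo) on the output tensor with exactly this helper:
//     make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))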

template <typename UpLengths, bool Use24BitIntegerCalculation = false>

@@ -308,6 +308,19 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                                    NewLowerDimensionOldVisibleIdss,
                                    NewUpperDimensionNewVisibleIdss)
{
    // sanity check
    {
        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewLowerDimensionOldVisibleIdss{});

        constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewUpperDimensionNewVisibleIdss{});

        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
    // sequences)

@@ -115,26 +115,30 @@ template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = Number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                return Number<1>{};
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                return math::lcm(lengths[Number<N - 1>{}], align);
                return Number<stride_n_minus_2>{};
            }
            else
            {
                return container_reduce(lengths,
                                        math::multiplies_v2{},
                                        math::lcm(lengths[Number<N - 1>{}], align),
                                        i,
                                        Number<stride_n_minus_2>{},
                                        i + I1,
                                        Number<N - 2>{},
                                        Number<1>{});
                                        I1);
            }
        },
        Number<N>{});
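// Worked example with hypothetical numbers, assuming integer_least_multiple
// rounds its first argument up to a multiple of the second: for lengths
// (N, C, H, W) = (2, 3, 4, 6) and align = 8, the fastest stride is 1 and
// stride_n_minus_2 = integer_least_multiple(6, 8) = 8 (the old code's
// lcm(6, 8) = 24 over-padded), so the outer strides become the running
// products 4 * 8 = 32 for C and 3 * 32 = 96 for N.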

@@ -31,8 +31,8 @@ template <index_t BlockSize,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          index_t ThreadTransferSrcResetCoordinateAfterRun,
          index_t ThreadTransferDstResetCoordinateAfterRun>
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

@@ -29,10 +29,10 @@ template <index_t BlockSize,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t MLevel0ThreadCluster,
          index_t NLevel0ThreadCluster,
          index_t MLevel1ThreadCluster,
          index_t NLevel1ThreadCluster,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          index_t AThreadCopyScalarPerVector_M1,
          index_t BThreadCopyScalarPerVector_N1,
          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
@@ -62,8 +62,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                      CThreadDesc::IsKnownAtCompileTime(),
                  "wrong! Desc should be known at compile-time");

    static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
                                   NLevel0ThreadCluster * NLevel1ThreadCluster,
    static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
                                   M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
                  "wrong! blocksize and cluster size not consistent");

    static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
@@ -78,6 +78,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
    constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

    // 4-d data space into 4-d thread space
    // upper: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, M1N1ThreadClusterN10 *
    // M1N1ThreadClusterN11} lower: {M0, M1, N0, N1}
    constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
        make_tuple(make_vectorize_transform(M0, 1),
                   make_vectorize_transform(M1PerThread, M1 / M1PerThread),
@@ -87,21 +89,27 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    // thread position 4-d thread space
    // upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
    // M1N1ThreadClusterN11} lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1,
    // M1N1ThreadClusterN10 * M1N1ThreadClusterN11}
    constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
        make_tuple(
            make_freeze_transform(make_multi_index(0)),
            make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
            make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
            make_freeze_transform(make_multi_index(0)),
            make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
            make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

    // 4-d thread space to 1-d thread space
    // upper: {BlockSize}
    // lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
    // M1N1ThreadClusterN11}
    constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
                                                   NLevel1ThreadCluster,
                                                   MLevel0ThreadCluster,
                                                   NLevel0ThreadCluster))),
        make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
                                                   M1N1ThreadClusterN10,
                                                   M1N1ThreadClusterM11,
                                                   M1N1ThreadClusterN11))),
        make_tuple(Sequence<0, 2, 1, 3>{}),
        make_tuple(Sequence<0>{}));

@@ -221,10 +229,10 @@ template <index_t BlockSize,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t MLevel0ThreadCluster,
          index_t NLevel0ThreadCluster,
          index_t MLevel1ThreadCluster,
          index_t NLevel1ThreadCluster,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          index_t AThreadCopyScalarPerVector_M1,
          index_t BThreadCopyScalarPerVector_N1,
          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
@@ -254,8 +262,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                      CThreadDesc::IsKnownAtCompileTime(),
                  "wrong! Desc should be known at compile-time");

    static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
                                   NLevel0ThreadCluster * NLevel1ThreadCluster,
    static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
                                   M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
                  "wrong! blocksize and cluster size not consistent");

    static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
@@ -287,18 +295,18 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
    constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
        make_tuple(
            make_freeze_transform(make_multi_index(0)),
            make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
            make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
            make_freeze_transform(make_multi_index(0)),
            make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
            make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

    // 4-d thread space to 1-d thread space
    constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
                                                   NLevel1ThreadCluster,
                                                   MLevel0ThreadCluster,
                                                   NLevel0ThreadCluster))),
        make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
                                                   M1N1ThreadClusterN10,
                                                   M1N1ThreadClusterM11,
                                                   M1N1ThreadClusterN11))),
        make_tuple(Sequence<0, 2, 1, 3>{}),
        make_tuple(Sequence<0>{}));


@@ -0,0 +1,396 @@
#ifndef CK_BLOCKWISE_GEMM_V2R2_HPP
#define CK_BLOCKWISE_GEMM_V2R2_HPP

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"

namespace ck {

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among each thread
// Assume:
//   1. A:
//     1. AKMBlockDesc is known at compile-time
//     2. ABlockBuffer is DynamicBuffer
//   2. B:
//     1. BKNBlockDesc is known at compile-time
//     2. BBlockBuffer is DynamicBuffer
//   3. C:
//     1. CM0M1N0N1ThreadDesc is known at compile-time
//     2. CThreadBuffer is StaticBuffer
// Also assume:
//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
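// Illustrative schedule of the 2x2 ABBA pipeline implemented by Run() below
// (a sketch of this file's own code, not additional functionality):
//   prologue: read A_sub_0, B_sub_0, B_sub_1, A_sub_1
//             fma C_sub_00, C_sub_01
//   per k-step: read next A_sub_0 | fma C_sub_10
//               read next B_sub_0 | fma C_sub_11
//               read next B_sub_1, A_sub_1 | fma C_sub_00, C_sub_01
//   epilogue: fma C_sub_10, C_sub_11
// so every read of an A/B sub-tile overlaps with an fma on data that is
// already in registers.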
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AKMBlockDesc,
          typename BKNBlockDesc,
          index_t M1PerThreadM11,
          index_t N1PerThreadN11,
          index_t KPerThread,
          index_t M1N1ThreadClusterM100,
          index_t M1N1ThreadClusterN100,
          index_t M1N1ThreadClusterM101,
          index_t M1N1ThreadClusterN101,
          index_t AThreadCopyScalarPerVector_M11,
          index_t BThreadCopyScalarPerVector_N11,
          typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
                                      BKNBlockDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
    static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
    static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);

    static constexpr index_t M100 = M1N1ThreadClusterM100;
    static constexpr index_t N100 = M1N1ThreadClusterN100;

    static constexpr index_t M101 = M1N1ThreadClusterM101;
    static constexpr index_t N101 = M1N1ThreadClusterN101;

    static constexpr index_t M11 = M1PerThreadM11;
    static constexpr index_t N11 = N1PerThreadN11;

    static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
    static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;

    static constexpr index_t M0 = M / M1;
    static constexpr index_t N0 = N / N1;

    __host__ __device__ static constexpr auto
    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
    {
        const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
            AKMBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return a_k_m0_m1_block_desc;
    }

    __host__ __device__ static constexpr auto
    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
    {
        const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
            BKNBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return b_k_n0_n1_block_desc;
    }

    __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M, N]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_unmerge_transform(make_tuple(
                               Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
    }

    __host__ __device__ static constexpr auto
    MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M0, M1, N0, N1]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(Number<M0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_pass_through_transform(Number<N0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
    }

    __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
    {
        return Sequence<M0, M11, N0, N11>{};
    }

    static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
    static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});

    public:
    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
              get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
    {
        static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == M101 * M100 * N101 * N100,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");

        static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2, "wrong");
    }

    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
    {
        // lower: [M0, M1, N0, N1]
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();

        // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
        // upper: [Tid, M0, M11, N0, N11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
                       make_pass_through_transform(M0),
                       make_pass_through_transform(M11),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(N11)),
            make_tuple(
                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

        return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
    }

    __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }

    __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; }

    template <typename CM0M1N0N1ThreadDesc,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
    __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
                      "wrong");

        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
            a_k_m0_m1_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
            b_k_n0_n1_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                FloatB,
                                                FloatC,
                                                decltype(a_k_m0_m1_thread_desc_),
                                                decltype(b_k_n0_n1_thread_desc_),
                                                CM0M1N0N1ThreadDesc,
                                                Sequence<KPerThread>,
                                                Sequence<1, M1PerThreadM11>,
                                                Sequence<1, N1PerThreadN11>>{};

        // read A_sub_0
        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                           make_tuple(I0, I0, I0),
                           a_block_buf,
                           a_k_m0_m1_thread_desc_,
                           make_tuple(I0, I0, I0),
                           a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                           make_tuple(I0, I0, I0),
                           b_block_buf,
                           b_k_n0_n1_thread_desc_,
                           make_tuple(I0, I0, I0),
                           b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                           make_tuple(I0, I1, I0),
                           b_block_buf,
                           b_k_n0_n1_thread_desc_,
                           make_tuple(I0, I1, I0),
                           b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                           make_tuple(I0, I1, I0),
                           a_block_buf,
                           a_k_m0_m1_thread_desc_,
                           make_tuple(I0, I1, I0),
                           a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I0, I0),
                            b_thread_buf,
                            make_tuple(I0, I0, I0),
                            c_thread_buf,
                            make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I0, I0),
                            b_thread_buf,
                            make_tuple(I0, I1, I0),
                            c_thread_buf,
                            make_tuple(I0, I0, I1, I0));

        // loop over rest of k
        static_for<KPerThread, K, KPerThread>{}([&](auto k) {
            // read A_sub_0
            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                               make_tuple(k, I0, I0),
                               a_block_buf,
                               a_k_m0_m1_thread_desc_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I1, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                               make_tuple(k, I0, I0),
                               b_block_buf,
                               b_k_n0_n1_thread_desc_,
                               make_tuple(I0, I0, I0),
                               b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I1, I0),
                                b_thread_buf,
                                make_tuple(I0, I1, I0),
                                c_thread_buf,
                                make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                               make_tuple(k, I1, I0),
                               b_block_buf,
                               b_k_n0_n1_thread_desc_,
                               make_tuple(I0, I1, I0),
                               b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                               make_tuple(k, I1, I0),
                               a_block_buf,
                               a_k_m0_m1_thread_desc_,
                               make_tuple(I0, I1, I0),
                               a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I1, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I1, I0),
                            b_thread_buf,
                            make_tuple(I0, I0, I0),
                            c_thread_buf,
                            make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I1, I0),
                            b_thread_buf,
                            make_tuple(I0, I1, I0),
                            c_thread_buf,
                            make_tuple(I1, I0, I1, I0));
    }

    private:
    // A[K, M0, M1]
    static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));

    // B[K, N0, N1]
    static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));

    using AThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                decltype(a_k_m0_m1_block_desc_),
                                                decltype(a_k_m0_m1_thread_desc_),
                                                Sequence<KPerThread, 1, M1PerThreadM11>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                AThreadCopyScalarPerVector_M11,
                                                1>;

    using BThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                FloatB,
                                                decltype(b_k_n0_n1_block_desc_),
                                                decltype(b_k_n0_n1_thread_desc_),
                                                Sequence<KPerThread, 1, N1PerThreadN11>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                BThreadCopyScalarPerVector_N11,
                                                1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif
@@ -0,0 +1,680 @@
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"

namespace ck {

template <typename GridwiseContraction,
          typename FloatAB,
          typename FloatC,
          typename AGKGM0GM10GM11GridDesc,
          typename BGKGN0GN10GN11GridDesc,
          typename CGM10BM0BM1GN10BN0BN1GridDesc,
          typename CBlockIdToGM10GN10BlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_dynamic_contraction_v1r1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AGKGM0GM10GM11GridDesc a_gk_gm0_gm10_gm11_grid_desc,
            const BGKGN0GN10GN11GridDesc b_gk_gn0_gn10_gn11_grid_desc,
            const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
            const CBlockIdToGM10GN10BlockClusterAdaptor
                c_blockid_to_gm10_gn10_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseContraction::Run(p_a_grid,
                             p_b_grid,
                             p_c_grid,
                             p_shared_block,
                             a_gk_gm0_gm10_gm11_grid_desc,
                             b_gk_gn0_gn10_gn11_grid_desc,
                             c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                             c_blockid_to_gm10_gn10_block_cluster_adaptor,
                             integral_constant<bool, HasMainKBlockLoop>{},
                             integral_constant<bool, HasDoubleTailKBlockLoop>{});
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGKGM0GM1GridDesc,
          typename BGKGN0GN1GridDesc,
          typename CGM0GM1GN0GN1GridDesc,
          index_t GM1PerBlockGM11,
          index_t GN1PerBlockGN11,
          index_t KPerBlock,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          index_t M11N11ThreadClusterM1100,
          index_t M11N11ThreadClusterN1100,
          index_t M11N11ThreadClusterM1101,
          index_t M11N11ThreadClusterN1101,
          typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
          typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_GM11,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
          typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_GN11,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
    static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
                                                 Number<BBlockTransferDstScalarPerVector_GN11>{},
                                                 Number<M1PerThreadM111>{},
                                                 Number<N1PerThreadN111>{});

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_gk_gm0_gm10_gm11_block_desc =
            make_dynamic_naive_tensor_descriptor_aligned_v2(
                make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_gk_gn0_gn10_gn11_block_desc =
            make_dynamic_naive_tensor_descriptor_aligned_v2(
                make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }
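    // Worked example with hypothetical tile sizes: KPerBlock = 8, GM0 = GN0 = 2,
    // GM1PerBlockGM11 = GN1PerBlockGN11 = 64 and 4-byte FloatAB give
    // 8 * 2 * 1 * 64 = 1024 elements per A and B tile, so the double buffering
    // behind the factor 2 above reserves 2 * (1024 + 1024) * 4 = 16 KiB of LDS,
    // ignoring any padding introduced by max_lds_align.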

    __host__ __device__ static constexpr bool
    CheckValidity(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
                  const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
                  const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
                          is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
                      "wrong!");

        const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
        const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
        const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
                 GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
                 GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
                 GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
                 GM0 == a_gk_gm0_gm1_grid_desc.GetLength(I1) &&
                 GM1 == a_gk_gm0_gm1_grid_desc.GetLength(I2) &&
                 GN0 == b_gk_gn0_gn1_grid_desc.GetLength(I1) &&
                 GN1 == b_gk_gn0_gn1_grid_desc.GetLength(I2) &&
                 GK == b_gk_gn0_gn1_grid_desc.GetLength(I0)) &&
                (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK % KPerBlock == 0));
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
    {
        const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
        const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);

        constexpr index_t GM11 = GM1PerBlockGM11;
        constexpr index_t GN11 = GN1PerBlockGN11;

        const index_t GM10 = GM1 / GM11;
        const index_t GN10 = GN1 / GN11;

        const index_t grid_size = GM10 * GN10;

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK)
    {
        const bool has_main_k_block_loop = (GK + KPerBlock) / (2 * KPerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK)
    {
        const bool has_double_tail_k_block_loop = (GK / KPerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }
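    // Worked example: with GK = 64 and KPerBlock = 8 the double-buffered
    // pipeline has (64 + 8) / (2 * 8) = 4 > 1, so a main K loop is emitted,
    // and (64 / 8) % 2 == 0 means the tail processes the last two K blocks.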
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAGKGM0GM10GM11GridDescriptor(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc)
|
||||
{
|
||||
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
|
||||
const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
|
||||
|
||||
const auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
const auto GM10 = GM1 / GM11;
|
||||
|
||||
const auto a_gk_gm0_gm10_gm11_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
a_gk_gm0_gm1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GK),
|
||||
make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return a_gk_gm0_gm10_gm11_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBGKGN0GN10GN11GridDescriptor(const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc)
|
||||
{
|
||||
const auto GK = b_gk_gn0_gn1_grid_desc.GetLength(I0);
|
||||
const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
|
||||
|
||||
const auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto b_gk_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
b_gk_gn0_gn1_grid_desc,
make_tuple(make_pass_through_transform(GK),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

return b_gk_gn0_gn10_gn11_grid_desc;
}

__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);

constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};

const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;

constexpr auto BM = GM0 * GM11;
constexpr auto BN = GN0 * GN11;

constexpr auto BM1 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto BN1 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};

constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;

const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm1_gn0_gn1_grid_desc,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)),
make_pass_through_transform(GN10),
make_merge_transform(make_tuple(GN0, GN11))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
make_pass_through_transform(GN10),
make_unmerge_transform(make_tuple(BN0, BN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
}

__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);

constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};

const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;

const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));

return c_blockid_to_gm10_gn10_block_cluster_adaptor;
}
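// Note (editorial): the adaptor above maps the flat 1-d block id back to a
// (gm10, gn10) block-cluster coordinate through a single merge transform.
// A rough sketch of the arithmetic, assuming row-major merge order:
//   igm10 = block_id / GN10;  ign10 = block_id % GN10;
// e.g. with GN10 = 8, block_id = 13 -> (igm10, ign10) = (1, 5).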

using AGKGM0GM10GM11GridDesc = decltype(MakeAGKGM0GM10GM11GridDescriptor(AGKGM0GM1GridDesc{}));
using BGKGN0GN10GN11GridDesc = decltype(MakeBGKGN0GN10GN11GridDescriptor(BGKGN0GN1GridDesc{}));
using CGM10BM0BM1GN10BN0BN1GridDesc =
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));

template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGKGM0GM10GM11GridDesc& a_gk_gm0_gm10_gm11_grid_desc,
const BGKGN0GN10GN11GridDesc& b_gk_gn0_gn10_gn11_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_gk_gm0_gm10_gm11_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_gk_gn0_gn10_gn11_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());

const auto GK = a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0);

// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));

// HACK: this forces the index data into SGPRs
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);

// lds max alignment
// part of this should be moved into the blockwise GEMM
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
Number<BBlockTransferDstScalarPerVector_GN11>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk_gm0_gm10_gm11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk_gn0_gn10_gn11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);

// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_gk_bm_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}), max_lds_align);

// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_gk_bn_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}), max_lds_align);

// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11>,
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_gk_gm0_gm10_gm11_grid_desc),
decltype(a_gk_gm0_gm10_gm11_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_GM11,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_gk_gm0_gm10_gm11_grid_desc,
make_multi_index(0, 0, igm10, 0),
a_gk_gm0_gm10_gm11_block_desc,
make_multi_index(0, 0, 0, 0));

// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11>,
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_gk_gn0_gn10_gn11_grid_desc),
decltype(b_gk_gn0_gn10_gn11_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_GN11,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_gk_gn0_gn10_gn11_grid_desc,
make_multi_index(0, 0, ign10, 0),
b_gk_gn0_gn10_gn11_block_desc,
make_multi_index(0, 0, 0, 0));

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlock, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_gk_bm_block_desc),
decltype(b_gk_bn_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);

constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);

FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});

constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);

// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};

// hack to control index calculation when moving the slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};

auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());

auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());

// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);
}
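// Note (editorial): a rough sketch of the LDS double-buffer pipeline that the
// loop below implements, assuming GK / KPerBlock iterations in total. While
// the blockwise GEMM consumes the buffer filled for iteration i, the blockwise
// copies fetch the data for iteration i + 1 into the other half:
//   even step: read(i + 1) -> gemm(even buf) -> write(odd buf)
//   odd  step: read(i + 2) -> gemm(odd buf)  -> write(even buf)
// so global loads overlap with math, with one __syncthreads() per step.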

if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;

// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);

// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);

// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);

k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < GK - 2 * KPerBlock);
}

// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{
a_blockwise_copy.MoveSrcSliceWindow(a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);

__syncthreads();

// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);

__syncthreads();

// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if 1 iteration is left
{
__syncthreads();

// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}

// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;

constexpr index_t M10 = GM1PerBlockGM11 / M11;
constexpr index_t N10 = GN1PerBlockGN11 / N11;

constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;

constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));

const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
Sequence<1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
make_multi_index(igm10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
ign10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
}
}
};

} // namespace ck
#endif
@@ -27,13 +27,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
@@ -63,13 +63,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast void __CONSTANT__* to void*
// second cast void* to Desc*
@@ -108,13 +108,13 @@ template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
@@ -139,14 +139,14 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
Number<M1PerThread>{},
Number<N1PerThread>{});

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
@@ -210,8 +210,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
Number<M1PerThread>{},
Number<N1PerThread>{});

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
@@ -284,34 +284,39 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
static_assert(
MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
"wrong!");

constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
constexpr index_t M0PerThread =
MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
constexpr index_t N0PerThread =
NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);

constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc,
make_tuple(
make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<M0PerThread>{},
Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc,
make_tuple(
make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<N0PerThread>{},
Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<M0PerThread>{},
Number<M1PerThread>{},
Number<N0PerThread>{},
Number<N1PerThread>{}));

const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
@@ -321,15 +326,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc),
decltype(c_m0_m1_n0_n1_thread_desc),
MPerThread,
NPerThread,
M1PerThread,
N1PerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
M1PerThread,
N1PerThread>{};

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
@@ -345,9 +350,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
ThreadwiseDynamicTensorSliceSet_v1<
FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
@@ -479,8 +485,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1

// output: register to global memory
{
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};

// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
@@ -493,7 +499,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
@@ -0,0 +1,621 @@
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"

namespace ck {

template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AKM0M1GridDesc,
typename BKN0N1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AKM0M1GridDesc a_k_m0_m1_grid_desc,
const BKN0N1GridDesc b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

__shared__ FloatAB p_shared_block[shared_block_size];

GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}

template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
index_t MPerBlockM1,
index_t NPerBlockN1,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};

__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
Number<BBlockTransferDstScalarPerVector_N1>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
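// Note (editorial): the factor of 2 above pays for double buffering. A rough
// worked example, assuming KPerBlock = 8, MPerBlockM1 = NPerBlockN1 = 128,
// FloatAB = float, and max_lds_align already dividing 8 * 128:
//   a_block = b_block = 8 * 128 elements
//   -> 2 * (1024 + 1024) * 4 bytes = 16 KiB of LDS per workgroup.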

__host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);

// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K == b_k_n_grid_desc.GetLength(I0)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0);
}

__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);

return grid_size;
}

__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;

return has_main_k_block_loop;
}

__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
{
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

return has_double_tail_k_block_loop;
}
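// Note (editorial): a rough worked example of the two predicates above,
// assuming KPerBlock = 8 (iterations = K / KPerBlock):
//   K = 8  -> 1 iteration:  no main loop, single tail (false, false)
//   K = 16 -> 2 iterations: no main loop, double tail (false, true)
//   K = 32 -> 4 iterations: main loop runs, double tail (true, true)
//   K = 40 -> 5 iterations: main loop runs, single tail (true, false)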

__host__ __device__ static constexpr auto
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{
const auto K = a_k_m_grid_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1);

const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1;

const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
a_k_m_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

return a_k_m0_m1_grid_desc;
}

__host__ __device__ static constexpr auto
MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
{
const auto K = b_k_n_grid_desc.GetLength(I0);
const auto N = b_k_n_grid_desc.GetLength(I1);

const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1;

const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
b_k_n_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

return b_k_n0_n1_grid_desc;
}

__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);

constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};

const auto M0 = M / M1;
const auto N0 = N / N1;

constexpr auto M11 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto N11 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};

constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;

const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

return c_m0_m10_m11_n0_n10_n11_grid_desc;
}

__host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);

constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};

const auto M0 = M / M1;
const auto N0 = N / N1;

const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));

return c_blockid_to_m0_n0_block_cluster_adaptor;
}
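// Note (editorial): the C descriptor above splits each output dimension into
// three levels, M = M0 * M10 * M11 (block grid / in-block cluster / per-thread
// work). A rough example, assuming M = 1024, MPerBlockM1 = 128 and M11 = 64:
//   M0 = 1024 / 128 = 8 blocks, M10 = 128 / 64 = 2, M11 = 64.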

using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));

template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

const auto K = a_k_m0_m1_grid_desc.GetLength(I0);

// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));

// HACK: this forces the index data into SGPRs
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
Number<BBlockTransferDstScalarPerVector_N1>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);

// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m0_m1_grid_desc),
decltype(a_k_m0_m1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k_m0_m1_grid_desc,
make_multi_index(0, im0, 0),
a_k_m0_m1_block_desc,
make_multi_index(0, 0, 0));

// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n0_n1_grid_desc),
decltype(b_k_n0_n1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k_n0_n1_grid_desc,
make_multi_index(0, in0, 0),
b_k_n0_n1_block_desc,
make_multi_index(0, 0, 0));

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
// b_mtx[KPerBlock, NPerBlockN1] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

constexpr auto c_m10_m11_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);

constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);

FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});

constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};

// hack to control index calculation when moving the slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};

auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());

auto a_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());

// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
}

if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;

// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);

// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);

k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}

// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);

__syncthreads();

// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);

// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

__syncthreads();

// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if 1 iteration is left
{
__syncthreads();

// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}

// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;

constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;

constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;

constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
}
}
};

} // namespace ck
#endif
@@ -6,6 +6,7 @@
#include "container_helper.hpp"
#include "statically_indexed_array.hpp"
#include "container_element_picker.hpp"
#include "multi_index.hpp"
#include "data_type.hpp"
#include "float_type.hpp"
#include "functional.hpp"

@@ -14,7 +14,7 @@
#define CK_DEVICE_BACKEND_AMD 1

// GPU ID
#if 0
#if 1
#define CK_AMD_GPU_GFX906 1
#elif 0
#define CK_AMD_GPU_GFX908 1
@@ -28,7 +28,7 @@
#endif

// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_USE_LAUNCH_BOUNDS 0

#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256

@@ -118,6 +118,7 @@ struct MagicDivision
return (tmp + dividend) >> shift;
}

#if 1 // debug
// HACK: magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
// non-negative for the result to be correct
@@ -127,8 +128,25 @@ struct MagicDivision
{
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t tmp = ((uint64_t)dividend_u32 * (uint64_t)multiplier) >> 32;
return (tmp + dividend_i32) >> shift;
return (tmp + dividend_u32) >> shift;
}
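// Note (editorial): a rough worked example of the path above, assuming a
// (multiplier, shift) pair precomputed so that
//   result = (umulhi(dividend, multiplier) + dividend) >> shift.
// For divisor = 3 one such pair is (1431655766, 2); then for dividend = 9:
//   umulhi(9, 1431655766) = 3 and (3 + 9) >> 2 = 3 = 9 / 3.
// The as_type<uint32_t> reinterpretation is why the dividend must be
// non-negative for the int32_t overload to be correct.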
#else
// the inline ASM produces a wrong result
__host__ __device__ static int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t r;
asm volatile("\n \
v_mul_hi_u32 %0, %1, %2 \n \
v_add_u32_e32 %0, %1, %0 \n \
v_lshrrev_b32_e32 %0, %3, %0 \n \
"
: "=v"(r)
: "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));

return as_type<int32_t>(r);
}
#endif
};

} // namespace ck

@@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
return (x + y - 1) / y;
return (x + y - Number<1>{}) / y;
}
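// Note (editorial): the switch to Number<1> above is presumably so that, when
// x and y are themselves Number<> constants, the whole expression stays a
// compile-time Number<> instead of decaying to a run-time integer.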

template <class X, class Y>

@@ -26,5 +26,11 @@ __host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}

template <index_t... Is>
__host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
{
return Sequence<Is...>{};
}
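// Note (editorial): to_sequence converts a tuple of compile-time Numbers into
// the equivalent Sequence, e.g.
//   to_sequence(make_tuple(Number<2>{}, Number<4>{})) -> Sequence<2, 4>{}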

} // namespace ck
#endif

@@ -99,7 +99,8 @@ __host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
printf("{");
printf("MultiIndex, ");
printf("size %d,", index_t{sizeof...(Xs)});
static_for<0, sizeof...(Xs), 1>{}([&](auto i) { printf("%d ", index_t{x.At(i)}); });
static_for<0, sizeof...(Xs), 1>{}(
[&](auto i) { printf("%d ", static_cast<index_t>(x.At(i))); });
printf("}");
}

@@ -16,6 +16,7 @@ install(TARGETS host LIBRARY DESTINATION lib)

if(DEVICE_BACKEND STREQUAL "AMD")
set(CONV_SOURCE src/conv_driver.cpp)
set(CONV_V2_SOURCE src/conv_driver_v2.cpp)
set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cpp)
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
set(CONV_SOURCE src/conv_driver.cu)
@@ -23,7 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
endif()

add_executable(conv_driver ${CONV_SOURCE})
add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})

target_link_libraries(conv_driver PRIVATE host)
target_link_libraries(conv_driver_v2 PRIVATE host)
target_link_libraries(conv_bwd_data_driver PRIVATE host)

@@ -2,15 +2,25 @@
#define CONV_COMMON_HPP

#include "tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor.hpp"

enum ConvTensorLayout
{
NCHW,
NHWC,
CHWN,
NCHWc,
NHWCc
};

template <class InDesc,
class WeiDesc,
class ConvStrides,
class ConvDilations,
class LowerPads,
class UpperPads>
class LeftPads,
class RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads)
{
using namespace ck;

@@ -35,21 +45,69 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
constexpr index_t Y = wei_desc.GetLength(I2);
constexpr index_t X = wei_desc.GetLength(I3);

constexpr index_t HPadLow = LowerPads{}.Get(I0);
constexpr index_t WPadLow = LowerPads{}.Get(I1);
constexpr index_t LeftPadH = LeftPads{}.Get(I0);
constexpr index_t LeftPadW = LeftPads{}.Get(I1);

constexpr index_t HPadUp = UpperPads{}.Get(I0);
constexpr index_t WPadUp = UpperPads{}.Get(I1);
constexpr index_t RightPadH = RightPads{}.Get(I0);
constexpr index_t RightPadW = RightPads{}.Get(I1);

constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;

constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;
constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1;
constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1;

return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
}

template <typename... InDesc,
typename... WeiDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
const ck::DynamicTensorDescriptor<InDesc...>& in_desc,
const ck::DynamicTensorDescriptor<WeiDesc...>& wei_desc,
const ConvStrides& conv_strides,
const ConvDilations conv_dilations,
const LeftPads& left_pads,
const RightPads& right_pads)
{
using namespace ck;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

assert(in_desc.GetNumOfDimension() == 4);
assert(wei_desc.GetNumOfDimension() == 4);
assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1));

const auto N = in_desc.GetLength(I0);
const auto Hi = in_desc.GetLength(I2);
const auto Wi = in_desc.GetLength(I3);

const auto K = wei_desc.GetLength(I0);
const auto Y = wei_desc.GetLength(I2);
const auto X = wei_desc.GetLength(I3);

const auto LeftPadH = left_pads[I0];
const auto LeftPadW = left_pads[I1];

const auto RightPadH = right_pads[I0];
const auto RightPadW = right_pads[I1];

const auto YEff = (Y - I1) * conv_dilations[I0] + I1;
const auto XEff = (X - I1) * conv_dilations[I1] + I1;

const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;

return make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
}
|
||||
template <class InDesc, class WeiDesc, class OutDesc>
|
||||
constexpr std::size_t
|
||||
calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
|
||||
|
||||
@@ -2,30 +2,29 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
-#include "driver_dynamic_gemm_v1.hpp"
+#include "driver_dynamic_gemm_v1r2.hpp"

-template <class TInWei,
-          ck::index_t InWeiVectorSize,
-          class TAcc,
-          class TOut,
-          class InDesc,
-          class WeiDesc,
-          class OutDesc,
-          class ConvStrides,
-          class ConvDilations,
-          class InLeftPads,
-          class InRightPads>
+template <typename TInWei,
+          typename TAcc,
+          typename TOut,
+          typename InLengths,
+          typename WeiLengths,
+          typename OutLengths,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
-   InDesc,
+   const InLengths& in_n_c_hi_wi_lengths,
+   const WeiLengths& wei_k_c_y_x_lengths,
+   const OutLengths& out_n_k_ho_wo_lengths,
+   const ConvStrides& conv_strides,
+   const ConvDilations& conv_dilations,
+   const InLeftPads& in_left_pads,
+   const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
-   WeiDesc,
    const Tensor<TInWei>& wei_k_c_y_x,
-   OutDesc,
    Tensor<TOut>& out_n_k_ho_wo,
-   ConvStrides,
-   ConvDilations,
-   InLeftPads,
-   InRightPads,
    ck::index_t nrepeat)
{
    using namespace ck;
@@ -50,505 +49,155 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

-#if 1
-   // run-time variables
    const auto in_n_c_hi_wi_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
+       make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
+       make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
+       make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);

-   const auto conv_strides = to_multi_index(ConvStrides{});
-   const auto conv_dilations = to_multi_index(ConvDilations{});
-   const auto in_left_pads = to_multi_index(InLeftPads{});
-   const auto in_right_pads = to_multi_index(InRightPads{});
-#else
-   // compile-time variables
-   const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-       sequence_to_tuple_of_number(InDesc::GetLengths()));
-   const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-       sequence_to_tuple_of_number(WeiDesc::GetLengths()));
-   const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-       sequence_to_tuple_of_number(OutDesc::GetLengths()));

-   const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
-   const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
-   const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
-   const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
-#endif

#if 0
-   // cdata = 16, BlockSize = 64, 16x64x4
-   constexpr index_t BlockSize = 64;

-   constexpr index_t GemmMPerBlock = 16;
-   constexpr index_t GemmNPerBlock = 64;
-   constexpr index_t GemmKPerBlock = 4;

-   constexpr index_t GemmMPerThread = 2;
-   constexpr index_t GemmNPerThread = 2;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 2;
-   constexpr index_t GemmNLevel1Cluster = 8;

-   constexpr index_t ThreadGemmDataPerReadM = 2;
-   constexpr index_t ThreadGemmDataPerReadN = 2;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
-   // cdata = 32, BlockSize 64, 16x128x4
-   constexpr index_t BlockSize = 64;

-   constexpr index_t GemmMPerBlock = 16;
-   constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 4;

-   constexpr index_t GemmMPerThread = 2;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 2;
-   constexpr index_t GemmNLevel1Cluster = 8;

-   constexpr index_t ThreadGemmDataPerReadM = 2;
-   constexpr index_t ThreadGemmDataPerReadN = 4;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
-   // cdata = 64, BlockSize 64, 16x256x2
-   constexpr index_t BlockSize = 64;

-   constexpr index_t GemmMPerBlock = 16;
-   constexpr index_t GemmNPerBlock = 256;
-   constexpr index_t GemmKPerBlock = 2;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 1;
-   constexpr index_t GemmNLevel1Cluster = 16;

-   constexpr index_t ThreadGemmDataPerReadM = 4;
-   constexpr index_t ThreadGemmDataPerReadN = 4;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
-   // cdata = 64, BlockSize 64, 16x256x4
-   constexpr index_t BlockSize = 64;

-   constexpr index_t GemmMPerBlock = 16;
-   constexpr index_t GemmNPerBlock = 256;
-   constexpr index_t GemmKPerBlock = 4;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 1;
-   constexpr index_t GemmNLevel1Cluster = 16;

-   constexpr index_t ThreadGemmDataPerReadM = 4;
-   constexpr index_t ThreadGemmDataPerReadN = 4;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
-   // cdata = 16, BlockSize = 64, 16x64x4
-   // GemmBBlockCopySrcDataPerRead_GemmN = 4
-   // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
-   constexpr index_t BlockSize = 64;

-   constexpr index_t GemmMPerBlock = 16;
-   constexpr index_t GemmNPerBlock = 64;
-   constexpr index_t GemmKPerBlock = 4;

-   constexpr index_t GemmMPerThread = 2;
-   constexpr index_t GemmNPerThread = 2;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 2;
-   constexpr index_t GemmNLevel1Cluster = 8;

-   constexpr index_t ThreadGemmDataPerReadM = 2;
-   constexpr index_t ThreadGemmDataPerReadN = 2;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
#elif 0
-   // cdata = 32, BlockSize = 64, 16x128x4
-   // GemmBBlockCopySrcDataPerRead_GemmN = 4
-   // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-   constexpr index_t BlockSize = 64;

-   constexpr index_t GemmMPerBlock = 16;
-   constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 4;

-   constexpr index_t GemmMPerThread = 2;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 2;
-   constexpr index_t GemmNLevel1Cluster = 8;

-   constexpr index_t ThreadGemmDataPerReadM = 2;
-   constexpr index_t ThreadGemmDataPerReadN = 4;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 0
-   // cdata = 64, BlockSize = 128, 32x256x8
-   constexpr index_t BlockSize = 128;

-   constexpr index_t GemmMPerBlock = 32;
-   constexpr index_t GemmNPerBlock = 256;
-   constexpr index_t GemmKPerBlock = 8;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 2;
-   constexpr index_t GemmNLevel1Cluster = 16;

-   constexpr index_t ThreadGemmDataPerReadM = 4;
-   constexpr index_t ThreadGemmDataPerReadN = 4;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
-   // cdata = 64, BlockSize = 256, 128x128x2
-   constexpr index_t BlockSize = 256;

-   constexpr index_t GemmMPerBlock = 128;
-   constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 2;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 8;
-   constexpr index_t GemmNLevel1Cluster = 8;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
-   // cdata = 64, BlockSize = 256, 128x128x4
-   constexpr index_t BlockSize = 256;

-   constexpr index_t GemmMPerBlock = 128;
-   constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 4;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 8;
-   constexpr index_t GemmNLevel1Cluster = 8;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
#if 1
    // cdata = 64, BlockSize = 256, 128x128x8
    // b thread copy 4x1
    constexpr index_t BlockSize = 256;

-   constexpr index_t GemmMPerBlock = 128;
-   constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 8;
+   constexpr index_t GemmMPerBlockM1 = 128;
+   constexpr index_t GemmNPerBlockN1 = 128;
+   constexpr index_t GemmKPerBlock = 8;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;
+   constexpr index_t GemmM1PerThreadM111 = 4;
+   constexpr index_t GemmN1PerThreadN111 = 4;
+   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 8;
-   constexpr index_t GemmNLevel1Cluster = 8;
+   constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
+   constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
+   constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
+   constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+   using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
+   using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
+   constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
+   constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+   using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
+   using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
+   constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
+   constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-   // cdata = 64, BlockSize = 256, 128x128x8
-   // b thread copy 2x2
-   constexpr index_t BlockSize = 256;

-   constexpr index_t GemmMPerBlock = 128;
-   constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 8;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 2;
-   constexpr index_t GemmNLevel0Cluster = 2;
-   constexpr index_t GemmMLevel1Cluster = 8;
-   constexpr index_t GemmNLevel1Cluster = 8;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-   // cdata = 64, BlockSize = 256, 128x128x16
-   // GemmBBlockCopySrcDataPerRead_GemmN = 4
-   // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-   constexpr index_t BlockSize = 256;

-   constexpr index_t GemmMPerBlock = 128;
-   constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 16;

-   constexpr index_t GemmMPerThread = 4;
-   constexpr index_t GemmNPerThread = 4;
-   constexpr index_t GemmKPerThread = 1;

-   constexpr index_t GemmMLevel0Cluster = 4;
-   constexpr index_t GemmNLevel0Cluster = 4;
-   constexpr index_t GemmMLevel1Cluster = 4;
-   constexpr index_t GemmNLevel1Cluster = 4;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;

-   constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-   constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-   using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;

-   constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-   constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;

-   constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
+   constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
#endif

-   constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
-   constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
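// Note on the tile math above (our reading of the config values, not spelled
// out in this commit): GemmM1 (resp. GemmN1) is the span of C rows (columns)
// covered by one level-0 x level-1 thread cluster, so for the surviving
// 128x128x8 config GemmM1 = 4 * 2 * 8 = 64 and the 128-row block tile splits
// as M0 x M1 = 2 x 64. The "cdata = 64" annotations match this arithmetic:
// a 128x128 C tile shared by 256 threads leaves 128 * 128 / 256 = 64
// accumulator values per thread.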

    const auto descs =
-#if 1
-       transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
-#elif 0
-       transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
-#else
-       transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
-#endif
-           <GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_c_y_x_desc,
-                                                          in_n_c_hi_wi_desc,
-                                                          out_n_k_ho_wo_desc,
-                                                          conv_strides,
-                                                          conv_dilations,
-                                                          in_left_pads,
-                                                          in_right_pads);
+       transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
+                                                                       in_n_c_hi_wi_desc,
+                                                                       out_n_k_ho_wo_desc,
+                                                                       conv_strides,
+                                                                       conv_dilations,
+                                                                       in_left_pads,
+                                                                       in_right_pads);
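// The transform above is the usual implicit-GEMM view of NCHW forward
// convolution (our reading of the v4r4 naming, not spelled out in this
// commit): the weight viewed as GemmK x GemmM = (C*Y*X) x K provides matrix
// A, the im2col-style gather of the input provides matrix B with
// GemmN = N * Ho * Wo, and the output is the GemmM x GemmN matrix C.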

+   // HACK: hacks that control index calculation when iterating over A, B, C matrix
+   constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
+       make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{}),
+                  make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{}));

+   constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
+       make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
+                  make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));

+   constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
+       make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 1, 0, 0>{},
+                             Sequence<0, 0, 1, 0, 0>{},
+                             Sequence<0, 0, 1, 0, 0>{}),
+                  make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 0, 0, 0>{},
+                             Sequence<0, 0, 2, 0, 0>{},
+                             Sequence<0, 0, 2, 0, 0>{},
+                             Sequence<0, 0, 2, 0, 0>{}));

+   constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
+       Sequence<0, 0, 0, 0, 0>{};

+   constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
+       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
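// Our reading of this hack encoding (it is not documented in the commit):
// each Sequence carries one flag per hidden dimension of the corresponding
// grid descriptor, the first tuple appears to be consulted when a coordinate
// steps forward and the second when it steps backward, and a non-zero flag
// lets the iterator skip part of the generic index recomputation for that
// transform.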

+   const auto wei_gemmk_gemmm_grid_desc = descs[I0];
+   const auto in_gemmk_gemmn_grid_desc = descs[I1];
+   const auto out_gemmm_gemmn_grid_desc = descs[I2];

    for(index_t i = 0; i < 5; ++i)
    {
-       float ave_time = launch_kernel_dynamic_gemm_v1<
+       float ave_time = driver_dynamic_gemm_v1r2<
            BlockSize,
-           typename vector_type<TInWei, InWeiVectorSize>::type,
+           TInWei,
            TAcc,
            TOut,
            InMemoryDataOperation::Set,
-           decltype(descs[I0]),
-           decltype(descs[I1]),
-           decltype(descs[I2]),
-           decltype(descs[I3]),
-           GemmMPerBlock,
-           GemmNPerBlock,
+           decltype(wei_gemmk_gemmm_grid_desc),
+           decltype(in_gemmk_gemmn_grid_desc),
+           decltype(out_gemmm_gemmn_grid_desc),
+           GemmMPerBlockM1,
+           GemmNPerBlockN1,
            GemmKPerBlock,
-           GemmMPerThread,
-           GemmNPerThread,
+           GemmM1PerThreadM111,
+           GemmN1PerThreadN111,
            GemmKPerThread,
-           GemmMLevel0Cluster,
-           GemmNLevel0Cluster,
-           GemmMLevel1Cluster,
-           GemmNLevel1Cluster,
-           GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-           GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-           Sequence<1, 0>,
-           Sequence<1, 0>,
-           0,
-           GemmABlockTransferSrcScalarPerVector_GemmK,
-           GemmABlockTransferDstScalarPerVector_GemmM,
+           GemmM11N11ThreadClusterM1100,
+           GemmM11N11ThreadClusterN1100,
+           GemmM11N11ThreadClusterM1101,
+           GemmM11N11ThreadClusterN1101,
+           GemmABlockTransferThreadSliceLengths_K_M0_M1,
+           GemmABlockTransferThreadClusterLengths_K_M0_M1,
+           Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
+           Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
+           0, // ABlockTransferSrcVectorDim
+           GemmABlockTransferSrcScalarPerVector_K,
+           GemmABlockTransferDstScalarPerVector_M1,
+           false, // don't move back src coordinate after threadwise copy
-           GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-           GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-           Sequence<0, 1>,
-           Sequence<0, 1>,
-           1,
-           GemmBBlockTransferSrcScalarPerVector_GemmN,
-           GemmBBlockTransferDstScalarPerVector_GemmN,
-           false, // don't move back src coordinate after threadwise copy, which will be fused with
-                  // MoveSrcSliceWindow() to save addr computation
-           Sequence<2, 3, 0, 1>,
-           3,
-           GemmCThreadTransferDstScalarPerVector_GemmN1,
-           decltype(descs[I4]),
-           decltype(descs[I5]),
-           decltype(descs[I6]),
-           decltype(descs[I7]),
-           decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                                    wei_k_c_y_x_device_buf.GetDeviceBuffer()),
-                                static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                                    in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-                                static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
-                                descs[I0],
-                                descs[I1],
-                                descs[I2],
-                                descs[I3],
-                                descs[I4],
-                                descs[I5],
-                                descs[I6],
-                                descs[I7],
-                                descs[I8],
-                                nrepeat);
+           GemmBBlockTransferThreadSliceLengths_K_N0_N1,
+           GemmBBlockTransferThreadClusterLengths_K_N0_N1,
+           Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
+           Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
+           2, // BBlockTransferSrcVectorDim
+           GemmBBlockTransferSrcScalarPerVector_N1,
+           GemmBBlockTransferDstScalarPerVector_N1,
+           false, // don't move back src coordinate after threadwise copy
+           Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
+           5, // CThreadTransferSrcDstVectorDim
+           GemmCThreadTransferDstScalarPerVector_N11,
+           decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
+           decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
+           decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
+           decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
+           decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
+           static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+           static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+           static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
+           wei_gemmk_gemmm_grid_desc,
+           in_gemmk_gemmn_grid_desc,
+           out_gemmm_gemmn_grid_desc,
+           wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
+           in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
+           out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
+           wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
+           in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
+           nrepeat);
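// For reference, the FLOP count used below is the standard direct-convolution
// count, 2 * N * K * C * Ho * Wo * Y * X (one multiply and one add per MAC),
// so perf comes out in (T)FLOPS once divided by the averaged kernel time.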

        float perf = (float)calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /

@@ -2,30 +2,29 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
-#include "driver_dynamic_gemm_v1.hpp"
+#include "driver_dynamic_gemm_v1r2.hpp"

-template <class TInWei,
-          ck::index_t InWeiVectorSize,
-          class TAcc,
-          class TOut,
-          class InDesc,
-          class WeiDesc,
-          class OutDesc,
-          class ConvStrides,
-          class ConvDilations,
-          class InLeftPads,
-          class InRightPads>
+template <typename TInWei,
+          typename TAcc,
+          typename TOut,
+          typename InLengths,
+          typename WeiLengths,
+          typename OutLengths,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
-   InDesc,
-   const Tensor<TInWei>& in_n_c_hi_wi,
-   WeiDesc,
-   const Tensor<TInWei>& wei_k_c_y_x,
-   OutDesc,
-   Tensor<TOut>& out_n_k_ho_wo,
-   ConvStrides,
-   ConvDilations,
-   InLeftPads,
-   InRightPads,
+   const InLengths& in_n_hi_wi_c_lengths,
+   const WeiLengths& wei_k_y_x_c_lengths,
+   const OutLengths& out_n_ho_wo_k_lengths,
+   const ConvStrides& conv_strides,
+   const ConvDilations& conv_dilations,
+   const InLeftPads& in_left_pads,
+   const InRightPads& in_right_pads,
+   const Tensor<TInWei>& in_n_hi_wi_c,
+   const Tensor<TInWei>& wei_k_y_x_c,
+   Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;
@@ -42,73 +41,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
    constexpr auto I7 = Number<7>{};
    constexpr auto I8 = Number<8>{};

-   constexpr auto N = OutDesc::GetLengths()[I0];
-   constexpr auto K = OutDesc::GetLengths()[I1];
-   constexpr auto C = WeiDesc::GetLengths()[I1];

-   constexpr auto Hi = InDesc::GetLengths()[I2];
-   constexpr auto Wi = InDesc::GetLengths()[I3];

-   constexpr auto Ho = OutDesc::GetLengths()[I2];
-   constexpr auto Wo = OutDesc::GetLengths()[I3];

-   constexpr auto Y = WeiDesc::GetLengths()[I2];
-   constexpr auto X = WeiDesc::GetLengths()[I3];

-   constexpr auto C0 = C / Number<InWeiVectorSize>{};
-   constexpr auto C1 = Number<InWeiVectorSize>{};
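// C0/C1 above split the channel dimension for vectorized access: with, say,
// C = 64 and InWeiVectorSize = 4, the tensors are handled as C0 = 16 groups
// of C1 = 4 contiguous values (example numbers, not from this commit).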

-#if 1
-   // run-time variables
-   constexpr auto in_n_hi_wi_c0_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
-   constexpr auto wei_k_y_x_c0_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0));
-   constexpr auto out_n_ho_wo_k_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K));

-   const auto conv_strides = to_multi_index(ConvStrides{});
-   const auto conv_dilations = to_multi_index(ConvDilations{});
-   const auto in_left_pads = to_multi_index(InLeftPads{});
-   const auto in_right_pads = to_multi_index(InRightPads{});
-#else
-   // compile-time variables
-   constexpr auto in_n_hi_wi_c0_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0));
-   constexpr auto wei_k_y_x_c0_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0));
-   constexpr auto out_n_ho_wo_k_desc =
-       make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K));

-   const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
-   const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
-   const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
-   const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
-#endif

-   Tensor<TInWei> in_n_hi_wi_c(
-       make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
-   Tensor<TInWei> wei_k_y_x_c(
-       make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
-   Tensor<TOut> out_n_ho_wo_k(
-       make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));

-   auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
-       in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi);
-   };

-   auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
-       wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x);
-   };

-   auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
-       out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo);
-   };

-   make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
-   make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
-   make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
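// The three (now removed) functors above only permuted layouts on the host,
// e.g. the NCHW input element (n, c, hi, wi) landed at (n, hi, wi, c) in the
// NHWC copy; after this change the caller supplies tensors already in the
// layout the NHWC kernel expects.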
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
@@ -117,357 +49,472 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 2;
|
||||
constexpr index_t GemmN1PerThreadN111 = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize = 64, 16x128x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 2;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 16x256x2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 1;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 1;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<2, 1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 16x256x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 1;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x4
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
constexpr index_t GemmMPerBlockM1 = 32;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x8
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlockM1 = 32;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<2, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;

constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;

constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256;

constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 16;

constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;

constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 2>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 64>;

constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 2;

using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;

constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#endif

constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
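// Sanity-check sketch of the tile arithmetic above (these asserts only
// restate relations implied by the chosen constants): each thread owns a
// GemmMPerThread x GemmNPerThread sub-tile, level-0 clusters are 2x2 threads
// and level-1 clusters are 8x8, so GemmM1 = 4 * 2 * 8 = 64 and the block
// holds 2 * 2 * 8 * 8 = 256 threads.
static_assert(BlockSize == GemmMLevel0Cluster * GemmNLevel0Cluster *
                               GemmMLevel1Cluster * GemmNLevel1Cluster,
              "thread clusters must tile the whole workgroup");
static_assert(GemmMPerBlock % GemmM1 == 0 && GemmNPerBlock % GemmN1 == 0,
              "M1/N1 sub-tiles must evenly divide the block tile");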

const auto descs =
#if 1
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
const auto descs =
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);

#if 0
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));

constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));

constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));

constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};

constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
#else
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));

constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));

constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};

constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
#endif
<GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_y_x_c0_desc,
in_n_hi_wi_c0_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);

#else
const auto descs =
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));

constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));

constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));

constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};

constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
#endif

const auto wei_gemmk_gemmm_grid_desc = descs[I0];
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];

for(index_t i = 0; i < 5; ++i)
{
float ave_time = launch_kernel_dynamic_gemm_v1<
float ave_time = driver_dynamic_gemm_v1r2<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(descs[I0]),
decltype(descs[I1]),
decltype(descs[I2]),
decltype(descs[I3]),
GemmMPerBlock,
GemmNPerBlock,
decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_K_M0_M1,
GemmABlockTransferThreadClusterLengths_K_M0_M1,
Sequence<1, 2, 0>, // ABlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0>, // ABlockTransferSrcAccessOrder
0, // ABlockTransferSrcVectorDim
GemmABlockTransferSrcScalarPerVector_K,
GemmABlockTransferDstScalarPerVector_M1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
Sequence<1, 2, 0>, // BBlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0>, // BBlockTransferSrcAccessOrder
0, // BBlockTransferSrcVectorDim
GemmBBlockTransferSrcScalarPerVector_K,
GemmBBlockTransferDstScalarPerVector_N1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
2, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_M11,
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
nrepeat);

float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];

const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];

const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];

float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}

// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());

auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
};

make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
}
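// The benchmark loop above counts flop = 2 * N * K * Ho * Wo * C * Y * X
// (one multiply and one add per MAC); dividing Gflop by milliseconds gives
// TFlop/s. A standalone restatement of that arithmetic (a sketch, not a
// helper that exists in this repo):
static double conv_fwd_tflops(std::size_t N, std::size_t K, std::size_t Ho,
                              std::size_t Wo, std::size_t C, std::size_t Y,
                              std::size_t X, double ave_time_ms)
{
    const double flop = 2.0 * N * K * Ho * Wo * C * Y * X;
    return flop / (1e9 * ave_time_ms); // Gflop per ms == TFlop/s
}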

@@ -0,0 +1,240 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_v1r1.hpp"

template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;

std::cout << __func__ << std::endl;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};

DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);

#if 1
// cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 4, 32] = [1, 128, 4, 32]
constexpr index_t BlockSize = 256;

constexpr index_t N0 = 4;

constexpr index_t GemmGM1PerBlockGM11 = 128;
constexpr index_t GemmGN1PerBlockGN11 = 32;
constexpr index_t GemmKPerBlock = 8;

constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;

constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;

constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;

using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 1, 1, 32>;

constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;

constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 8, 16] = [1, 128, 8, 16]
constexpr index_t BlockSize = 256;

constexpr index_t N0 = 8;

constexpr index_t GemmGM1PerBlockGM11 = 128;
constexpr index_t GemmGN1PerBlockGN11 = 16;
constexpr index_t GemmKPerBlock = 8;

constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;

constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;

constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;

using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 2, 1, 16>;

constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;

constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#endif
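// One way to read the "[8, 1, 128] * [8, 4, 32] = [1, 128, 4, 32]" comments
// above (an interpretation; the diff does not spell it out): per-block tiles
//   A: [GemmKPerBlock, GM0, GemmGM1PerBlockGM11] = [8, 1, 128]
//   B: [GemmKPerBlock, GN0, GemmGN1PerBlockGN11] = [8, 4, 32]   (the N0 = 4 variant)
//   C: [GM0, GM11, GN0, GN11]                    = [1, 128, 4, 32]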

const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);

const auto wei_gk_gm0_gm1_grid_desc = descs[I0];
const auto in_gk_gn0_gn1_grid_desc = descs[I1];
const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2];

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gk_gm0_gm10_gm11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{}));

constexpr auto in_gk_gn0_gn10_gn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));

constexpr auto out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}));

constexpr auto wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0>{};

constexpr auto in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>{};

for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_contraction_v1r1<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(wei_gk_gm0_gm1_grid_desc),
decltype(in_gk_gn0_gn1_grid_desc),
decltype(out_gm0_gm1_gn0_gn1_grid_desc),
GemmGM1PerBlockGM11,
GemmGN1PerBlockGN11,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
Sequence<3, 2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
Sequence<3, 2, 1, 0>, // ABlockTransferSrcAccessOrder
0, // ABlockTransferSrcVectorDim
GemmABlockTransferSrcScalarPerVector_GK,
GemmABlockTransferDstScalarPerVector_GM11,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
Sequence<0, 3, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
Sequence<0, 3, 2, 1>, // BBlockTransferSrcAccessOrder
3, // BBlockTransferSrcVectorDim
GemmBBlockTransferSrcScalarPerVector_GN11,
GemmBBlockTransferDstScalarPerVector_GN11,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_BN1,
decltype(wei_gk_gm0_gm10_gm11_grid_iterator_hacks),
decltype(in_gk_gn0_gn10_gn11_grid_iterator_hacks),
decltype(out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks),
decltype(wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks),
decltype(in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gk_gm0_gm1_grid_desc,
in_gk_gn0_gn1_grid_desc,
out_gm0_gm1_gn0_gn1_grid_desc,
wei_gk_gm0_gm10_gm11_grid_iterator_hacks,
in_gk_gn0_gn10_gn11_grid_iterator_hacks,
out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks,
wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks,
in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks,
nrepeat);

float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}

// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
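// calculate_convolution_flops is assumed here to return
// 2 * N * K * Ho * Wo * C * Y * X, matching the explicit flop count used by
// the v4r4 driver earlier in this diff (an assumption; its definition is not
// part of this hunk).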
@@ -4,97 +4,64 @@
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp"

template <class TInWei,
template <typename TInWei,
ck::index_t InWeiVectorSize,
class TAcc,
class TOut,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
InDesc,
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc,
const Tensor<TInWei>& wei_k_c_y_x,
OutDesc,
Tensor<TOut>& out_n_k_ho_wo,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;

std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
<< std::endl;

DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
std::cout << __func__ << std::endl;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto N = OutDesc::GetLengths()[I0];
constexpr auto K = OutDesc::GetLengths()[I1];
constexpr auto C = WeiDesc::GetLengths()[I1];
const auto N = out_n_k_ho_wo_lengths[I0];
const auto K = out_n_k_ho_wo_lengths[I1];
const auto C = wei_k_c_y_x_lengths[I1];

constexpr auto Hi = InDesc::GetLengths()[I2];
constexpr auto Wi = InDesc::GetLengths()[I3];
const auto Hi = in_n_c_hi_wi_lengths[I2];
const auto Wi = in_n_c_hi_wi_lengths[I3];

constexpr auto Ho = OutDesc::GetLengths()[I2];
constexpr auto Wo = OutDesc::GetLengths()[I3];
const auto Ho = out_n_k_ho_wo_lengths[I2];
const auto Wo = out_n_k_ho_wo_lengths[I3];

constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3];
const auto Y = wei_k_c_y_x_lengths[I2];
const auto X = wei_k_c_y_x_lengths[I3];

constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{};
const auto C0 = C / Number<InWeiVectorSize>{};
const auto C1 = Number<InWeiVectorSize>{};

constexpr auto K0 = K / Number<InWeiVectorSize>{};
constexpr auto K1 = Number<InWeiVectorSize>{};
const auto K0 = K / Number<InWeiVectorSize>{};
const auto K1 = Number<InWeiVectorSize>{};

#if 0
// run-time variables
const auto in_n_c0_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));

const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_multi_index(ConvDilations{});
const auto in_left_pads = to_multi_index(InLeftPads{});
const auto in_right_pads = to_multi_index(InRightPads{});
#else
// compile-time variables
const auto in_n_c0_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));

const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif

Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{})));
Tensor<TInWei> in_n_c0_hi_wi_c1(
HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
Tensor<TInWei> wei_k_c0_y_x_c1(
HostTensorDescriptor(std::initializer_list<index_t>{K, C0, Y, X, C1}));
Tensor<TOut> out_n_k0_ho_wo_k1(
HostTensorDescriptor(std::initializer_list<index_t>{N, K0, Ho, Wo, K1}));

auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
@@ -109,17 +76,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();

in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());

in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());

const auto in_n_c0_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));

#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
constexpr index_t BlockSize = 64;

constexpr index_t KPerBlock = K;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = C0;
constexpr index_t EPerBlock = 1;

constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 2;
@@ -134,7 +114,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(

constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;

constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
constexpr index_t CThreadTransferDstScalarPerVector_W = 16;

static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#else
@@ -165,17 +145,28 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(

constexpr auto conv_driver =
#if 0
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#else
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
#endif
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, TAcc, TOut, KPerBlock,
HoPerBlock, WoPerBlock, EPerBlock, KPerThread, HoPerThread, WoPerThread,
EPerThread, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K, BThreadTransferSrcScalarPerVector_W,
CThreadTransferDstScalarPerVector_W > {};
<BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
TOut,
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
BThreadTransferSrcScalarPerVector_W,
CThreadTransferDstScalarPerVector_W>{};

conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc,
@@ -185,12 +176,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));

out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());

auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) =

@@ -6,58 +6,94 @@ template <class TIn,
class TOut,
class ConvStrides,
class ConvDilations,
class LowerPads,
class UpperPads>
void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
Tensor<TOut>& out_nkhw,
ConvStrides,
ConvDilations,
LowerPads,
UpperPads)
class InLeftPads,
class InRightPads>
void host_direct_convolution(const Tensor<TIn>& in,
const Tensor<TWei>& wei,
Tensor<TOut>& out,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
using namespace ck;

index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

auto f = [&](auto n, auto k, auto ho, auto wo) {
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
{
int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low;
for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
{
int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low;
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3])
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
v += static_cast<const double>(in_nchw(n, c, hi, wi)) *
static_cast<const double>(wei_kcyx(k, c, y, x));
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(wei(k, c, y, x));
}
}
}
}
out_nkhw(n, k, ho, wo) = v;
out(n, k, ho, wo) = v;
};

auto f_par = make_ParallelTensorFunctor(f,
out_nkhw.mDesc.GetLengths()[0],
out_nkhw.mDesc.GetLengths()[1],
out_nkhw.mDesc.GetLengths()[2],
out_nkhw.mDesc.GetLengths()[3]);
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
double v = 0;
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
{
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(wei(k, y, x, c));
}
}
}
}
out(n, ho, wo, k) = v;
};

f_par(std::thread::hardware_concurrency());
switch(layout)
{
case ConvTensorLayout::NCHW:
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
break;
case ConvTensorLayout::NHWC:
make_ParallelTensorFunctor(f_nhwc,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
break;
default: throw std::runtime_error("wrong! not supported layout");
}
}
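// A minimal usage sketch of the updated host_direct_convolution signature
// (a sketch; tensor shapes are illustrative, and the vector-based
// HostTensorDescriptor constructor is the one introduced later in this diff):
//
//   Tensor<float> in(HostTensorDescriptor(std::vector<std::size_t>{1, 8, 16, 16}));  // NCHW
//   Tensor<float> wei(HostTensorDescriptor(std::vector<std::size_t>{4, 8, 3, 3}));   // KCYX
//   Tensor<float> out(HostTensorDescriptor(std::vector<std::size_t>{1, 4, 16, 16})); // NKHW
//   host_direct_convolution(in, wei, out,
//                           ck::Sequence<1, 1>{}, ck::Sequence<1, 1>{},
//                           ck::Sequence<1, 1>{}, ck::Sequence<1, 1>{},
//                           ConvTensorLayout::NCHW);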

template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
template <class TIn, class TWei, class TOut, class InLeftPads, class InRightPads>
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
Tensor<TOut>& out_nkhw,
LowerPads,
UpperPads)
InLeftPads,
InRightPads)
{
using namespace ck;

@@ -76,8 +112,8 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];

index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = InLeftPads{}.Get(Number<0>{});
index_t w_pad_low = InLeftPads{}.Get(Number<1>{});

std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1;

@@ -271,19 +271,20 @@ struct Tensor
std::vector<T> mData;
};

void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout)
template <typename X>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
{
os << "dim " << desc.GetNumOfDimension() << ", ";

os << "lengths {";
LogRange(os, desc.GetLengths(), ", ");
os << "}, ";

os << "strides {";
LogRange(os, desc.GetStrides(), ", ");
os << "}" << std::endl;
this->CalculateStrides();
}

template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
: mLens(lens), mStrides(strides)
{
}
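// Usage sketch for the two constructors above (lengths and strides are
// illustrative; the element type is assumed to match the descriptor's
// internal vectors):
//
//   HostTensorDescriptor a(std::vector<std::size_t>{4, 3, 2});  // strides computed
//   HostTensorDescriptor b(std::vector<std::size_t>{4, 3, 2},
//                          std::vector<std::size_t>{6, 2, 1});  // strides given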

void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);

template <class T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{

@@ -44,7 +44,7 @@ struct GeneratorTensor_Checkboard
template <class... Ts>
double operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{Xs...}};
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
return std::accumulate(dims.begin(),
dims.end(),
true,

@@ -14,19 +14,41 @@
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"

int main(int argc, char* argv[])
{
using namespace ck;

if(argc != 5)
{
printf("arg1: do_verification, arg2: init_method, arg3: do_log, arg4: nrepeat\n");
exit(1);
}

const bool do_verification = atoi(argv[1]);
const int init_method = atoi(argv[2]);
const bool do_log = atoi(argv[3]);
const int nrepeat = atoi(argv[4]);

#if 0
constexpr index_t N = 8;
constexpr index_t C = 8;
constexpr index_t Hi = 4;
constexpr index_t Wi = 8;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
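// For each configuration block here, the output spatial size follows the
// standard convolution relation (not restated in the diff):
//   Ho = (Hi + LeftPad0 + RightPad0 - Dilation0 * (Y - 1) - 1) / Stride0 + 1
//   Wo = (Wi + LeftPad1 + RightPad1 - Dilation1 * (X - 1) - 1) / Stride1 + 1
// e.g. the block above (Hi = 4, Y = 3, pads 1/1, stride 1) gives Ho = 4.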
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t Hi = 540;
constexpr index_t Wi = 960;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -34,13 +56,13 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 540;
constexpr index_t WI = 960;
constexpr index_t Hi = 270;
constexpr index_t Wi = 480;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -48,27 +70,13 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 270;
constexpr index_t WI = 480;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;

using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t Hi = 1080;
constexpr index_t Wi = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -76,13 +84,13 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t HI = 1024;
constexpr index_t WI = 2048;
constexpr index_t Hi = 1024;
constexpr index_t Wi = 2048;
constexpr index_t K = 4;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -90,13 +98,13 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 540;
constexpr index_t WI = 960;
constexpr index_t Hi = 540;
constexpr index_t Wi = 960;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -104,13 +112,13 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 270;
constexpr index_t WI = 480;
constexpr index_t Hi = 270;
constexpr index_t Wi = 480;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -118,14 +126,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 3x3, 36x36, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 37;
constexpr index_t WI = 37;
constexpr index_t Hi = 37;
constexpr index_t Wi = 37;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -133,14 +141,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -148,14 +156,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t Hi = 71;
constexpr index_t Wi = 71;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -163,14 +171,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t Hi = 8;
constexpr index_t Wi = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -178,14 +186,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t Hi = 73;
constexpr index_t Wi = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -193,14 +201,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -208,14 +216,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t Hi = 71;
constexpr index_t Wi = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -223,14 +231,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t Hi = 17;
constexpr index_t Wi = 17;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 1;
@@ -238,14 +246,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
using InLeftPads = Sequence<3, 0>;
using InRightPads = Sequence<3, 0>;
#elif 1
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t Hi = 17;
constexpr index_t Wi = 17;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
@@ -253,14 +261,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
using InLeftPads = Sequence<0, 3>;
using InRightPads = Sequence<0, 3>;
#elif 0
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 299;
constexpr index_t WI = 299;
constexpr index_t Hi = 299;
constexpr index_t Wi = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -268,14 +276,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t Hi = 147;
constexpr index_t Wi = 147;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -283,14 +291,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
using InLeftPads = Sequence<1, 1>;
using InRightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 149;
constexpr index_t WI = 149;
constexpr index_t Hi = 149;
constexpr index_t Wi = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -298,14 +306,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t Hi = 17;
constexpr index_t Wi = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -313,14 +321,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -328,14 +336,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 288;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t Hi = 35;
constexpr index_t Wi = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -343,14 +351,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t Hi = 8;
constexpr index_t Wi = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
@@ -358,14 +366,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 1>;
using RightPads = Sequence<0, 1>;
using InLeftPads = Sequence<0, 1>;
using InRightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t Hi = 8;
constexpr index_t Wi = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
@@ -373,14 +381,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
using InLeftPads = Sequence<1, 0>;
using InRightPads = Sequence<1, 0>;
#elif 0
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t Hi = 147;
constexpr index_t Wi = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -388,14 +396,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t Hi = 73;
constexpr index_t Wi = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
@@ -403,14 +411,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
using InLeftPads = Sequence<3, 0>;
using InRightPads = Sequence<3, 0>;
#elif 0
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t Hi = 73;
constexpr index_t Wi = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -418,14 +426,14 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;

using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using InLeftPads = Sequence<0, 0>;
using InRightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t Hi = 14;
constexpr index_t Wi = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -433,14 +441,14 @@
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -448,14 +456,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -463,14 +471,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -478,14 +486,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
// 3x3, 14x14
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -493,14 +501,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 56x56, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -508,14 +516,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 7x7, 230x230 stride=2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 3;
|
||||
constexpr index_t HI = 230;
|
||||
constexpr index_t WI = 230;
|
||||
constexpr index_t Hi = 230;
|
||||
constexpr index_t Wi = 230;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 7;
|
||||
@@ -523,14 +531,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 28x28, stride = 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -538,14 +546,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 28x28, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -553,14 +561,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
// 1x1, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t Hi = 7;
|
||||
constexpr index_t Wi = 7;
|
||||
constexpr index_t K = 2048;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -568,14 +576,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t Hi = 7;
|
||||
constexpr index_t Wi = 7;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -583,14 +591,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -598,14 +606,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -613,82 +621,86 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#endif
|
||||
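Note that several branches in this cascade are written as `#elif 1`; the preprocessor compiles only the first branch whose condition holds, so — assuming no branch above this excerpt is taken — the active configuration is the first `#elif 1` block (3x3 filter on a 28x28 input, N = 128, C = K = 128, pad 1), and the later `#elif 1` blocks are dead.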

    auto in_nchw_desc  = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
    auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
    auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;

    ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
    ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
    ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
    print_array("LeftPads", to_multi_index(LeftPads{}));
    print_array("RightPads", to_multi_index(RightPads{}));
    print_array("ConvStrides", to_multi_index(ConvStrides{}));
    print_array("ConvDilations", to_multi_index(ConvDilations{}));
    constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1;
    constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1;

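As a sanity check on the formulas above, here is a small self-contained constexpr helper (an illustration added for this write-up, not code from the commit); for the active 3x3, 28x28, pad-1, stride-1 configuration it reproduces Ho = Wo = 28:

// Illustrative only: mirrors the YEff/XEff effective-filter-size and Ho/Wo formulas above.
constexpr int conv_out_size(int in, int filter, int pad_l, int pad_r, int stride, int dilation)
{
    return (in + pad_l + pad_r - ((filter - 1) * dilation + 1)) / stride + 1;
}
static_assert(conv_out_size(28, 3, 1, 1, 1, 1) == 28, "3x3, pad 1, stride 1 keeps 28x28");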
#if 1
    using in_data_t = float;
    constexpr index_t in_vector_size = 1;
    using in_data_t  = typename vector_type<float, in_vector_size>::type;
    using acc_data_t = float;
    using out_data_t = float;
#elif 0
    using in_data_t = float;
    constexpr index_t in_vector_size = 1;
    using in_data_t  = typename vector_type<float, in_vector_size>::type;
    using acc_data_t = float;
    using out_data_t = int8_t;
#elif 1
    using in_data_t = int8_t;
    constexpr index_t in_vector_size = 16;
    using in_data_t  = typename vector_type<int8_t, in_vector_size>::type;
    using acc_data_t = int32_t;
    using out_data_t = int8_t;
#endif

    Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
    Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
    Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
    Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
    Tensor<in_data_t> in_nchw(HostTensorDescriptor(std::initializer_list<index_t>{N, C, Hi, Wi}));
    Tensor<in_data_t> wei_kcyx(HostTensorDescriptor(std::initializer_list<index_t>{K, C, Y, X}));
    Tensor<out_data_t> out_nkhw_host(
        HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
    Tensor<out_data_t> out_nkhw_device(
        HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));

    ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: ");
    ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: ");
    ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: ");

    print_array("InLeftPads", InLeftPads{});
    print_array("InRightPads", InRightPads{});
    print_array("ConvStrides", ConvStrides{});
    print_array("ConvDilations", ConvDilations{});

    std::size_t num_thread = std::thread::hardware_concurrency();

    if(argc != 4)
    {
        printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
        exit(1);
    }

    bool do_verification = atoi(argv[1]);
    bool do_log          = atoi(argv[2]);
    index_t nrepeat      = atoi(argv[3]);

    if(do_verification)
    {
#if 0
        in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
        in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0
        in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1
        in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0
        in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
        switch(init_method)
        {
        case 0:
            in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 1:
            in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        case 2:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 3:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        default:
            in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);

            auto gen_wei = [](auto... is) {
                return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
            };
            wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
#endif
        auto gen_wei = [](auto... is) {
            return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
        };
        wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
        }
    }

#if 0
    constexpr auto in_nchw_desc  = make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
    constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
    constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});

#if 1
    device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                                 in_nchw,
                                                                 wei_kcyx_desc,
@@ -697,8 +709,8 @@ int main(int argc, char* argv[])
        out_nkhw_device,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{},
        InLeftPads{},
        InRightPads{},
        nrepeat);
#elif 0
    device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
@@ -709,8 +721,8 @@ int main(int argc, char* argv[])
        out_nkhw_device,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{},
        InLeftPads{},
        InRightPads{},
        nrepeat);
#elif 0
    device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
@@ -721,58 +733,9 @@ int main(int argc, char* argv[])
        out_nkhw_device,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{},
        InLeftPads{},
        InRightPads{},
        nrepeat);
#elif 0
    device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
                                                                         in_vector_size,
                                                                         acc_data_t,
                                                                         out_data_t>(
        in_nchw_desc,
        in_nchw,
        wei_kcyx_desc,
        wei_kcyx,
        out_nkhw_desc,
        out_nkhw_device,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{},
        nrepeat);
#elif 1
    device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                         in_vector_size,
                                                                         acc_data_t,
                                                                         out_data_t>

        (in_nchw_desc,
        in_nchw,
        wei_kcyx_desc,
        wei_kcyx,
        out_nkhw_desc,
        out_nkhw_device,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{},
        nrepeat);
#elif 1
    device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
                                                                         in_vector_size,
                                                                         acc_data_t,
                                                                         out_data_t>(
        in_nchw_desc,
        in_nchw,
        wei_kcyx_desc,
        wei_kcyx,
        out_nkhw_desc,
        out_nkhw_device,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{},
        nrepeat);
#endif

    if(do_verification)
@@ -782,8 +745,8 @@ int main(int argc, char* argv[])
        out_nkhw_host,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{});
        InLeftPads{},
        InRightPads{});

    check_error(out_nkhw_host, out_nkhw_device);

410
driver/src/conv_driver_v2.cpp
Normal file
@@ -0,0 +1,410 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"

#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4_NHWC 1
#define USE_CONV_FWD_V4R5_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0

enum ConvForwardAlgo
{
    V4R4NCHW,
    V4R4NHWC,
    V4R5NCHW,
    V5R1NCHW
};

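Since `algo` is produced below by `static_cast<ConvForwardAlgo>(atoi(argv[2]))`, the enumerators map to their declaration order: 0 = V4R4NCHW, 1 = V4R4NHWC, 2 = V4R5NCHW, 3 = V5R1NCHW. So an invocation like `driver/conv_driver_v2 1 1 ...` (illustrative; assuming NHWC is the layout value 1 in ConvTensorLayout from conv_common.hpp) would select the v4r4 NHWC path.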
int main(int argc, char* argv[])
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};

#if USE_DYNAMIC_MODE
    // dynamic mode
    if(argc != 22)
    {
        printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
        printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
        exit(1);
    }

    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
    const ConvForwardAlgo algo    = static_cast<ConvForwardAlgo>(atoi(argv[2]));
    const bool do_verification    = atoi(argv[3]);
    const int init_method         = atoi(argv[4]);
    const bool do_log             = atoi(argv[5]);
    const int nrepeat             = atoi(argv[6]);

    const index_t N  = atoi(argv[7]);
    const index_t K  = atoi(argv[8]);
    const index_t C  = atoi(argv[9]);
    const index_t Y  = atoi(argv[10]);
    const index_t X  = atoi(argv[11]);
    const index_t Hi = atoi(argv[12]);
    const index_t Wi = atoi(argv[13]);

    const index_t conv_stride_h   = atoi(argv[14]);
    const index_t conv_stride_w   = atoi(argv[15]);
    const index_t conv_dilation_h = atoi(argv[16]);
    const index_t conv_dilation_w = atoi(argv[17]);
    const index_t in_left_pad_h   = atoi(argv[18]);
    const index_t in_left_pad_w   = atoi(argv[19]);
    const index_t in_right_pad_h  = atoi(argv[20]);
    const index_t in_right_pad_w  = atoi(argv[21]);

    const index_t YEff = (Y - 1) * conv_dilation_h + 1;
    const index_t XEff = (X - 1) * conv_dilation_w + 1;

    const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
    const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#else
    // static mode
    if(argc < 7)
    {
        printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
        exit(1);
    }

    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
    const ConvForwardAlgo algo    = static_cast<ConvForwardAlgo>(atoi(argv[2]));
    const bool do_verification    = atoi(argv[3]);
    const int init_method         = atoi(argv[4]);
    const bool do_log             = atoi(argv[5]);
    const int nrepeat             = atoi(argv[6]);

    constexpr index_t N  = 128;
    constexpr index_t C  = 128;
    constexpr index_t Hi = 17;
    constexpr index_t Wi = 17;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 7;

    const index_t conv_stride_h   = 1;
    const index_t conv_stride_w   = 1;
    const index_t conv_dilation_h = 1;
    const index_t conv_dilation_w = 1;
    const index_t in_left_pad_h   = 0;
    const index_t in_left_pad_w   = 3;
    const index_t in_right_pad_h  = 0;
    const index_t in_right_pad_w  = 3;

    const index_t YEff = (Y - 1) * conv_dilation_h + 1;
    const index_t XEff = (X - 1) * conv_dilation_w + 1;

    const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
    const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif
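Worked through for the static-mode defaults above (1x7 filter on a 17x17 input, horizontal pads of 3): XEff = (7 - 1) * 1 + 1 = 7, so Wo = (17 + 3 + 3 - 7) / 1 + 1 = 17, and with YEff = 1, Ho = 17 as well; the pads are chosen so the 1x7 convolution preserves the 17x17 spatial size.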

#if 1
    constexpr index_t in_vector_size = 1;
    using in_data_t  = float;
    using acc_data_t = float;
    using out_data_t = float;
#elif 1
    constexpr index_t in_vector_size = 16;
    using in_data_t  = int8_t;
    using acc_data_t = int32_t;
    using out_data_t = int8_t;
#endif

    std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);

    switch(layout)
    {
    case ConvTensorLayout::NCHW:
        // NCHW
        in_lengths_host[0]  = static_cast<std::size_t>(N);
        in_lengths_host[1]  = static_cast<std::size_t>(C);
        in_lengths_host[2]  = static_cast<std::size_t>(Hi);
        in_lengths_host[3]  = static_cast<std::size_t>(Wi);
        wei_lengths_host[0] = static_cast<std::size_t>(K);
        wei_lengths_host[1] = static_cast<std::size_t>(C);
        wei_lengths_host[2] = static_cast<std::size_t>(Y);
        wei_lengths_host[3] = static_cast<std::size_t>(X);
        out_lengths_host[0] = static_cast<std::size_t>(N);
        out_lengths_host[1] = static_cast<std::size_t>(K);
        out_lengths_host[2] = static_cast<std::size_t>(Ho);
        out_lengths_host[3] = static_cast<std::size_t>(Wo);
        break;
    case ConvTensorLayout::NHWC:
        // NHWC
        in_lengths_host[0]  = static_cast<std::size_t>(N);
        in_lengths_host[1]  = static_cast<std::size_t>(Hi);
        in_lengths_host[2]  = static_cast<std::size_t>(Wi);
        in_lengths_host[3]  = static_cast<std::size_t>(C);
        wei_lengths_host[0] = static_cast<std::size_t>(K);
        wei_lengths_host[1] = static_cast<std::size_t>(Y);
        wei_lengths_host[2] = static_cast<std::size_t>(X);
        wei_lengths_host[3] = static_cast<std::size_t>(C);
        out_lengths_host[0] = static_cast<std::size_t>(N);
        out_lengths_host[1] = static_cast<std::size_t>(Ho);
        out_lengths_host[2] = static_cast<std::size_t>(Wo);
        out_lengths_host[3] = static_cast<std::size_t>(K);
        break;
    default: throw std::runtime_error("wrong! not implemented");
    }

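The Tensor objects built from these lengths get their strides from HostTensorDescriptor's lengths-only constructor, which calls CalculateStrides() (see the host_tensor.cpp hunk further down). Assuming the packed row-major layout sketched after that hunk, the NCHW case above yields strides {C*Hi*Wi, Hi*Wi, Wi, 1} and the NHWC case {Hi*Wi*C, Wi*C, C, 1}.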
    Tensor<in_data_t> in(in_lengths_host);
    Tensor<in_data_t> wei(wei_lengths_host);
    Tensor<out_data_t> out_host(out_lengths_host);
    Tensor<out_data_t> out_device(out_lengths_host);

    std::cout << "layout: " << layout << std::endl;
    ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
    ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
    ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
    print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
    print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
    print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
    print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));

    std::size_t num_thread = std::thread::hardware_concurrency();

    if(do_verification)
    {
        switch(init_method)
        {
        case 0:
            in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 1:
            in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        case 2:
            in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
            break;
        case 3:
            in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
            break;
        default:
            in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);

            auto gen_wei = [](auto... is) {
                return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
            };
            wei.GenerateTensorValue(gen_wei, num_thread);
        }
    }

    auto f_make_for_device_nchw = [&]() {
#if USE_DYNAMIC_MODE
        const auto in_lengths_dev     = make_tuple(N, C, Hi, Wi);
        const auto wei_lengths_dev    = make_tuple(K, C, Y, X);
        const auto out_lengths_dev    = make_tuple(N, K, Ho, Wo);
        const auto conv_strides_dev   = make_tuple(conv_stride_h, conv_stride_w);
        const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
        const auto in_left_pads_dev   = make_tuple(in_left_pad_h, in_left_pad_w);
        const auto in_right_pads_dev  = make_tuple(in_right_pad_h, in_right_pad_w);
#else
        const auto in_lengths_dev =
            make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
        const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
        const auto out_lengths_dev =
            make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
        const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
        const auto conv_dilations_dev =
            make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
        const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
        const auto in_right_pads_dev =
            make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
#endif

        return make_tuple(in_lengths_dev,
                          wei_lengths_dev,
                          out_lengths_dev,
                          conv_strides_dev,
                          conv_dilations_dev,
                          in_left_pads_dev,
                          in_right_pads_dev);
    };

    auto f_make_for_device_nhwc = [&]() {
#if USE_DYNAMIC_MODE
        const auto in_lengths_dev     = make_tuple(N, Hi, Wi, C);
        const auto wei_lengths_dev    = make_tuple(K, Y, X, C);
        const auto out_lengths_dev    = make_tuple(N, Ho, Wo, K);
        const auto conv_strides_dev   = make_tuple(conv_stride_h, conv_stride_w);
        const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
        const auto in_left_pads_dev   = make_tuple(in_left_pad_h, in_left_pad_w);
        const auto in_right_pads_dev  = make_tuple(in_right_pad_h, in_right_pad_w);
#else
        const auto in_lengths_dev =
            make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
        const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
        const auto out_lengths_dev =
            make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
        const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
        const auto conv_dilations_dev =
            make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
        const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
        const auto in_right_pads_dev =
            make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
#endif

        return make_tuple(in_lengths_dev,
                          wei_lengths_dev,
                          out_lengths_dev,
                          conv_strides_dev,
                          conv_dilations_dev,
                          in_left_pads_dev,
                          in_right_pads_dev);
    };

    const auto nhwc_desc = f_make_for_device_nhwc();

#if USE_CONV_FWD_V4R4_NCHW
    if(algo == ConvForwardAlgo::V4R4NCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
                                                                             acc_data_t,
                                                                             out_data_t>(
            tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6], in, wei, out_device, nrepeat);
    }
#endif

#if USE_CONV_FWD_V4R4_NHWC
    if(algo == ConvForwardAlgo::V4R4NHWC)
    {
        if(layout != ConvTensorLayout::NHWC)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nhwc();

        device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                             acc_data_t,
                                                                             out_data_t>(
            tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6], in, wei, out_device, nrepeat);
    }
#endif

#if USE_CONV_FWD_V4R5_NCHW
    if(algo == ConvForwardAlgo::V4R5NCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw<in_data_t,
                                                                             acc_data_t,
                                                                             out_data_t>(
            tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6], in, wei, out_device, nrepeat);
    }
#endif

#if USE_CONV_FWD_V5R1_NCHW
    if(algo == ConvForwardAlgo::V5R1NCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
                                                                             in_vector_size,
                                                                             acc_data_t,
                                                                             out_data_t>(
            tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6], in, wei, out_device, nrepeat);
    }
#endif

    if(do_verification)
    {
        host_direct_convolution(in,
                                wei,
                                out_host,
                                make_tuple(conv_stride_h, conv_stride_w),
                                make_tuple(conv_dilation_h, conv_dilation_w),
                                make_tuple(in_left_pad_h, in_left_pad_w),
                                make_tuple(in_right_pad_h, in_right_pad_w),
                                layout);

        check_error(out_host, out_device);

        if(do_log)
        {
            LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
            LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl;
            LogRange(std::cout << "out_host  : ", out_host.mData, ",") << std::endl;
            LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
        }
    }
}
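check_error() comes from the host-tensor utilities and its exact tolerance policy is not shown in this diff. A comparison in the same spirit — a sketch only, with a hypothetical name, assuming equally sized buffers — might look like:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: report the largest absolute element-wise difference between a reference
// (host) result and a device result; the caller picks a tolerance suited to the type.
template <typename T>
double max_abs_diff(const std::vector<T>& ref, const std::vector<T>& res)
{
    double max_diff = 0;
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double d = std::abs(static_cast<double>(ref[i]) - static_cast<double>(res[i]));
        max_diff = std::max(max_diff, d);
    }
    return max_diff;
}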
@@ -3,18 +3,6 @@

#include "host_tensor.hpp"

template <typename X>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
{
    this->CalculateStrides();
}

template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
    : mLens(lens), mStrides(strides)
{
}

void HostTensorDescriptor::CalculateStrides()
{
    mStrides.clear();
@@ -45,3 +33,16 @@ std::size_t HostTensorDescriptor::GetElementSpace() const
const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; }

const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }

void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
{
    os << "dim " << desc.GetNumOfDimension() << ", ";

    os << "lengths {";
    LogRange(os, desc.GetLengths(), ", ");
    os << "}, ";

    os << "strides {";
    LogRange(os, desc.GetStrides(), ", ");
    os << "}" << std::endl;
}

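The body of CalculateStrides() is truncated by the hunk above. Given that the lengths-only constructor calls it and the resulting tensors behave as packed, a minimal sketch of the presumed dense row-major stride computation — my reconstruction, not the verbatim body — is:

#include <cstddef>
#include <vector>

// Sketch: innermost dimension is contiguous; each stride is the product of all
// faster-varying lengths, e.g. lens {N, C, H, W} -> strides {C*H*W, H*W, W, 1}.
inline void calculate_packed_strides(const std::vector<std::size_t>& lens,
                                     std::vector<std::size_t>& strides)
{
    strides.assign(lens.size(), 1);
    if(lens.empty())
        return;
    for(std::size_t i = lens.size() - 1; i-- > 0;)
        strides[i] = strides[i + 1] * lens[i + 1];
}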
17
script/run.sh
Executable file
@@ -0,0 +1,17 @@
#!/bin/bash

#make -j conv_driver
make -j conv_driver_v2

LAYOUT=$1
ALGO=$2
VERIFY=$3
INIT=$4
LOG=$5
REPEAT=$6

###################### layout algo verify init log repeat N__ K__ C__ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 384 192 3 3 35 35 2 2 1 1 0 0 0 0
#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 1 7 17 17 1 1 1 1 0 3 0 3
#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1
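For the active (uncommented) line: N = 128, K = 128, C = 192, a 3x3 filter over a 71x71 input, stride 2, dilation 1, and pads of 1 on every side, so the driver computes Ho = Wo = (71 + 1 + 1 - 3) / 2 + 1 = 36.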