improve implicit gemm NCHW, SRCK, NKHW, and tuned

This commit is contained in:
Chao Liu
2019-01-24 22:24:30 -06:00
parent 1de6fd0753
commit 8bd6ea1a97
4 changed files with 60 additions and 22 deletions

View File

@@ -17,6 +17,7 @@ template <unsigned GridSize,
unsigned CPerBlock,
unsigned HoPerBlock,
unsigned WoPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned HoPerThread,
@@ -32,7 +33,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
constexpr unsigned NPerThread = NPerBlock;
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
const unsigned n_thread_data_begin =
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
// output: register to global mem,
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
@@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
out_hkwn_thread_desc,
p_out_thread,
out_nkhw_global_desc,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),