This commit is contained in:
Chao Liu
2019-01-24 16:15:51 -06:00
parent c39c573eb8
commit bd811e2c20
2 changed files with 15 additions and 10 deletions

View File

@@ -9,7 +9,7 @@
#include "device_direct_convolution_1.cuh"
#include "device_direct_convolution_2.cuh"
//#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
//#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
//#include "device_winograd_convolution.cuh"
@@ -418,9 +418,9 @@ int main()
device_direct_convolution_2
#elif 0
device_implicit_gemm_convolution_1_nchw_kcsr
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 0
device_implicit_gemm_convolution_2_cnhw_srck_knhw
#elif 0
device_winograd_convolution

View File

@@ -103,6 +103,15 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
}
#endif
// blockwise copy
// wei: format is [S,R,C,K], no conversion needed
constexpr auto blockwise_wei_copy =
blockwise_4d_tensor_copy_1<BlockSize,
Float,
decltype(wei_srck_global_desc),
decltype(wei_srck_block_desc),
decltype(wei_srck_block_desc.GetLengths())>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
@@ -171,13 +180,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
#if 1
// weight: global mem to LDS,
// format is [S,R,C,K], no conversion needed
blockwise_4d_tensor_copy<BlockSize>(
wei_srck_global_desc,
p_wei_global +
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin),
wei_srck_block_desc,
p_wei_block,
wei_srck_block_desc.GetLengths());
blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex(
0, 0, c_block_data_begin, k_block_data_begin),
p_wei_block);
#endif
__syncthreads();