From bd811e2c2055d8da8d13d7987c40f133943de6a0 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 24 Jan 2019 16:15:51 -0600 Subject: [PATCH] refactor --- driver/conv.cu | 6 +++--- ...icit_gemm_convolution_1_nchw_srck_nkhw.cuh | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/driver/conv.cu b/driver/conv.cu index d669f5b12c..97316e889d 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -9,7 +9,7 @@ #include "device_direct_convolution_1.cuh" #include "device_direct_convolution_2.cuh" //#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh" -//#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" +#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" //#include "device_winograd_convolution.cuh" @@ -418,9 +418,9 @@ int main() device_direct_convolution_2 #elif 0 device_implicit_gemm_convolution_1_nchw_kcsr -#elif 0 - device_implicit_gemm_convolution_1_nchw_srck_nkhw #elif 1 + device_implicit_gemm_convolution_1_nchw_srck_nkhw +#elif 0 device_implicit_gemm_convolution_2_cnhw_srck_knhw #elif 0 device_winograd_convolution diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index 04c9000a9e..d5f3308633 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -103,6 +103,15 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, } #endif + // blockwise copy + // wei: format is [S,R,C,K], no conversion needed + constexpr auto blockwise_wei_copy = + blockwise_4d_tensor_copy_1{}; + // a series of blockwise batched GEMM // C_matrix += transpose(A_matrix) * B_matrix // A_matrix and B_matrix saved in LDS, C_matrix saved in register @@ -171,13 +180,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, #if 1 // weight: global mem to LDS, // format is [S,R,C,K], no conversion needed - blockwise_4d_tensor_copy( - wei_srck_global_desc, - p_wei_global + - wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin), - wei_srck_block_desc, - p_wei_block, - wei_srck_block_desc.GetLengths()); + blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex( + 0, 0, c_block_data_begin, k_block_data_begin), + p_wei_block); #endif __syncthreads();