From abf75ac039420f7a4ab64a419416dd493b906742 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 11 Feb 2019 17:45:22 -0600 Subject: [PATCH] refactor --- driver/conv.cu | 12 ++- ...icit_gemm_convolution_1_chwn_csrk_khwn.cuh | 16 +++- src/include/ConstantMatrixDescriptor.cuh | 2 +- src/include/blockwise_gemm.cuh | 76 +++++++++---------- src/include/common.cuh | 18 ++--- ...icit_gemm_convolution_1_chwn_csrk_khwn.cuh | 18 ++--- ...mm_convolution_1_chwn_csrk_khwn_padded.cuh | 18 ++--- ...n_1_chwn_csrk_khwn_padded_lds_pipeline.cuh | 18 ++--- ..._implicit_gemm_convolution_1_nchw_kcsr.cuh | 16 ++-- ...icit_gemm_convolution_1_nchw_srck_nkhw.cuh | 16 ++-- ...icit_gemm_convolution_2_cnhw_csrk_knhw.cuh | 26 +++---- ...ion_2_cnhw_csrk_knhw_lds_double_buffer.cuh | 16 ++-- ...icit_gemm_convolution_2_cnhw_srck_knhw.cuh | 14 ++-- ...volution_2_cnhw_srck_knhw_lds_pipeline.cuh | 14 ++-- src/include/threadwise_gemm.cuh | 6 +- 15 files changed, 140 insertions(+), 146 deletions(-) diff --git a/driver/conv.cu b/driver/conv.cu index 9826230284..249e9d0562 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -391,7 +391,7 @@ int main() constexpr unsigned HPad = 0; constexpr unsigned WPad = 0; -#elif 0 +#elif 1 // 3x3, 34x34 constexpr unsigned N = 64; constexpr unsigned C = 256; @@ -430,6 +430,9 @@ int main() constexpr unsigned K = 64; constexpr unsigned S = 5; constexpr unsigned R = 5; + + constexpr unsigned HPad = 0; + constexpr unsigned WPad = 0; #elif 0 // 7x7, 38x38 constexpr unsigned N = 64; @@ -439,6 +442,9 @@ int main() constexpr unsigned K = 64; constexpr unsigned S = 7; constexpr unsigned R = 7; + + constexpr unsigned HPad = 0; + constexpr unsigned WPad = 0; #elif 0 // 3x3, 58x58 constexpr unsigned N = 16; @@ -484,7 +490,7 @@ int main() constexpr unsigned HPad = 1; constexpr unsigned WPad = 1; -#elif 1 +#elif 0 // 1x1 filter, 28x28 image constexpr unsigned N = 16; constexpr unsigned C = 256; @@ -608,7 +614,7 @@ int main() nrepeat); #endif -#if 0 +#if 1 if(S == 3 && R == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh index bf87dc1cf3..7976492690 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh +++ b/driver/device_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh @@ -137,12 +137,18 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 1; + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet + constexpr unsigned WeiBlockCopyDataPerRead = 4; + constexpr unsigned BlockSize = 128; -#elif 0 +#elif 1 // for 7x7, 38x38 constexpr unsigned NPerBlock = 8; constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 2; + constexpr unsigned CPerBlock = 1; constexpr unsigned HoPerBlock = 4; constexpr unsigned WoPerBlock = 4; @@ -152,6 +158,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 1; + constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; + constexpr unsigned WeiBlockCopyThreadPerDim1 = 32; + + constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet + constexpr unsigned WeiBlockCopyDataPerRead = 4; + constexpr unsigned BlockSize = 128; #elif 0 // for 3x3, 56x56 diff --git a/src/include/ConstantMatrixDescriptor.cuh b/src/include/ConstantMatrixDescriptor.cuh index 43d391762a..bf141eb4f0 100644 --- a/src/include/ConstantMatrixDescriptor.cuh +++ b/src/include/ConstantMatrixDescriptor.cuh @@ -4,7 +4,7 @@ template struct ConstantMatrixDescriptor { - __host__ __device__ ConstantMatrixDescriptor() + __host__ __device__ constexpr ConstantMatrixDescriptor() { static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!"); } diff --git a/src/include/blockwise_gemm.cuh b/src/include/blockwise_gemm.cuh index dfcd1c4c88..ef5c092de4 100644 --- a/src/include/blockwise_gemm.cuh +++ b/src/include/blockwise_gemm.cuh @@ -124,12 +124,12 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC { if(TransA && (!TransB) && (!TransC)) { - constexpr auto True = Constant{}; - constexpr auto False = Constant{}; + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; - const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile - const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile - const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed @@ -137,11 +137,11 @@ struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC constexpr unsigned NPerThread = c_thread_mtx.NCol(); // a is transposed, b is not - const auto a_thread_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto a_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto b_thread_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto b_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; @@ -278,8 +278,8 @@ struct BlockwiseGemmBlockABlockBThreadC if(TransA && (!TransB) && (!TransC)) { - const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile - const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile + constexpr auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile + constexpr auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), "wrong! k dimension not consistent!"); @@ -287,7 +287,7 @@ struct BlockwiseGemmBlockABlockBThreadC constexpr unsigned MPerBlock = a_block_mtx.NCol(); constexpr unsigned NPerBlock = b_block_mtx.NCol(); - const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + constexpr auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile // divide thread work constexpr unsigned MPerThread = c_thread_mtx.NRow(); @@ -374,12 +374,12 @@ struct BlockwiseGemmBlockABlockBThreadC { if(TransA && (!TransB) && (!TransC)) { - constexpr auto True = Constant{}; - constexpr auto False = Constant{}; + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; - const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile - const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile - const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed @@ -387,11 +387,11 @@ struct BlockwiseGemmBlockABlockBThreadC constexpr unsigned NPerThread = c_thread_mtx.NCol(); // a is transposed, b is not - const auto a_thread_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto a_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto b_thread_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto b_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; @@ -556,8 +556,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatC* p_c_thread, Accumulator f_accum) const { - constexpr auto True = Constant{}; - constexpr auto False = Constant{}; + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile @@ -648,12 +648,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatC* p_c_thread, Accumulator f_accum) const { - constexpr auto True = Constant{}; - constexpr auto False = Constant{}; + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; - const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile - const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile - const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr unsigned M = a_block_mtx.NCol(); constexpr unsigned N = b_block_mtx.NCol(); @@ -663,22 +663,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr unsigned NPerThread = c_thread_mtx.NCol(); // thread A, B for GEMM - const auto a_thread_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto a_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto b_thread_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto b_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); // thread A-sub, B-sub for copy - const auto a_thread_sub_mtx = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto b_thread_sub_mtx = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); // register FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()]; diff --git a/src/include/common.cuh b/src/include/common.cuh index 9885cabf84..461ea6fe11 100644 --- a/src/include/common.cuh +++ b/src/include/common.cuh @@ -1,6 +1,8 @@ #pragma once -#define WARPSIZE 32; +__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } + +__device__ unsigned get_block_1d_id() { return blockIdx.x; } template struct is_same @@ -14,20 +16,16 @@ struct is_same static const bool value = true; }; -__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } - -__device__ unsigned get_block_1d_id() { return blockIdx.x; } - template -struct Constant +struct integral_constant { - static const T mValue = N; + static const T value = N; - __host__ __device__ constexpr T Get() const { return mValue; } + __host__ __device__ constexpr T Get() const { return value; } }; template -using Number = Constant; +using Number = integral_constant; template struct Sequence @@ -64,4 +62,4 @@ struct Sequence printf("Sequence::ReorderByPutOldToNew not implemented"); assert(false); } -}; +}; \ No newline at end of file diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh index 30a8dda2a2..326e5939ef 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn.cuh @@ -156,18 +156,16 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc, // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] - const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxwn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); - const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxwn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); const auto blockwise_batch_gemm = Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxwn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); - const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxwn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); const auto blockwise_batch_gemm = Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxwn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); - const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxwn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); const auto blockwise_batch_gemm = Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}, Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxwn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); - const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxwn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); const auto blockwise_batch_gemm = Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}, Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxwn_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); - const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxwn_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); const auto blockwise_batch_gemm = Blockwise1dStridedBatchedGemmBlockABlockBThreadC{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); #if 0 const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); #if 0 const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}, Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC{}, Number{}); // constexpr doesn't compile + constexpr auto a_cxk_block_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); - const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, - Number{}, - Number{}); // constexpr doesn't compile + constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); - const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); // constexpr doesn't compile + constexpr auto c_kxb_thread_mtx_desc = + make_ConstantMatrixDescriptor(Number{}, Number{}); const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC __device__ void threadwise_gemm(MatrixA, - Constant, + integral_constant, FloatA* const p_a_thread, MatrixB, - Constant, + integral_constant, FloatB* const p_b_thread, MatrixC, - Constant, + integral_constant, FloatC* p_c_thread, Accumulator f_accum) {