From c866773642cf7cf8386303b6167844d93e608547 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 6 Feb 2019 23:44:21 -0600 Subject: [PATCH] unroll some loop, register double buffer gemm --- driver/conv.cu | 2 +- ...mm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh | 32 +--- src/include/blockwise_gemm.cuh | 147 ++++++++++++++++++ ...icit_gemm_convolution_2_cnhw_csrk_knhw.cuh | 13 +- 4 files changed, 158 insertions(+), 36 deletions(-) diff --git a/driver/conv.cu b/driver/conv.cu index cecf4737e8..cbb783d344 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -611,7 +611,7 @@ int main() nrepeat); #endif -#if 0 +#if 1 if(S == 3 && R == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh index 4bf88d9edd..716e6d6691 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2.cuh @@ -66,42 +66,12 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc, Tensor out_knhw(make_TensorDescriptor(out_knhw_desc)); -#if 0 +#if 1 // 1x1, 28x28 constexpr unsigned BPerBlock = 64; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 8; - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 16; - - constexpr unsigned GemmMPerThreadSubC = 16; - constexpr unsigned GemmNPerThreadSubC = 4; - constexpr unsigned GemmMLevel0Cluster = 4; - constexpr unsigned GemmNLevel0Cluster = 8; - constexpr unsigned GemmMLevel1Cluster = 1; - constexpr unsigned GemmNLevel1Cluster = 2; - constexpr unsigned GemmKPerThreadLoop = 1; - - constexpr unsigned GemmThreadPerColumnPerCluster = 4; - constexpr unsigned GemmThreadPerRowPerCluster = 8; - - constexpr unsigned InBlockCopyThreadPerDim0 = 4; - constexpr unsigned InBlockCopyThreadPerDim1 = 16; - - constexpr unsigned WeiBlockCopyThreadPerDim0 = 4; - constexpr unsigned WeiBlockCopyThreadPerDim1 = 16; - - constexpr unsigned InBlockCopyDataPerRead = 4; - constexpr unsigned WeiBlockCopyDataPerRead = 4; - - constexpr unsigned BlockSize = 64; -#elif 1 - // 1x1, 28x28 try - constexpr unsigned BPerBlock = 64; - constexpr unsigned KPerBlock = 64; - constexpr unsigned CPerBlock = 8; - constexpr unsigned BPerThread = 8; constexpr unsigned KPerThread = 8; diff --git a/src/include/blockwise_gemm.cuh b/src/include/blockwise_gemm.cuh index 49ceeec168..29c0cf7695 100644 --- a/src/include/blockwise_gemm.cuh +++ b/src/include/blockwise_gemm.cuh @@ -598,9 +598,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; +#pragma unroll // loop over k for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { +#pragma unroll // copy A-sub to form A for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { @@ -613,6 +615,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 a_thread_sub_mtx.GetLengths()); } +#pragma unroll // copy B-sub to form B for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { @@ -638,4 +641,148 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 f_accum); } } + + template + __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block, + FloatB* const p_b_block, + FloatC* p_c_thread, + Accumulator f_accum) const + { + constexpr auto True = Constant{}; + constexpr auto False = Constant{}; + + const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile + const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile + const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + + constexpr unsigned M = a_block_mtx.NCol(); + constexpr unsigned N = b_block_mtx.NCol(); + constexpr unsigned K = a_block_mtx.NRow(); + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + // thread A, B for GEMM + const auto a_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto b_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + // thread A-sub, B-sub for copy + const auto a_thread_sub_mtx = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + const auto b_thread_sub_mtx = + make_ConstantMatrixDescriptor(Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()]; + + FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()]; + + constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + + constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + + constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; + constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + + // preload A, B +#pragma unroll + for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { // copy A-sub to form A + threadwise_matrix_copy(a_block_mtx, + p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster, + a_thread_sub_mtx, + p_a_thread_0 + m_repeat * MPerThreadSubC, + a_thread_sub_mtx.GetLengths()); + } + +#pragma unroll + for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { // copy B-sub to form B + threadwise_matrix_copy(b_block_mtx, + p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster, + b_thread_sub_mtx, + p_b_thread_0 + n_repeat * NPerThreadSubC, + b_thread_sub_mtx.GetLengths()); + } + + bool even_loop = true; + +#pragma unroll + for(unsigned k_begin = 0; k_begin + 1 < K; + k_begin += KPerThreadLoop, even_loop = !even_loop) + { // loop over k + FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; + FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1; + + FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0; + FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0; + + // preload next A, B + +#pragma unroll + for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { // copy A-sub to form A + threadwise_matrix_copy(a_block_mtx, + p_a_block + mMyThreadOffsetA + + (k_begin + 1) * a_block_mtx.RowStride() + + m_repeat * MPerLevel1Cluster, + a_thread_sub_mtx, + p_a_thread_next + m_repeat * MPerThreadSubC, + a_thread_sub_mtx.GetLengths()); + } + +#pragma unroll + for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { // copy B-sub to form B + threadwise_matrix_copy(b_block_mtx, + p_b_block + mMyThreadOffsetB + + (k_begin + 1) * b_block_mtx.RowStride() + + n_repeat * NPerLevel1Cluster, + b_thread_sub_mtx, + p_b_thread_next + n_repeat * NPerThreadSubC, + b_thread_sub_mtx.GetLengths()); + } + + // C = A * B + threadwise_gemm(a_thread_mtx, + True, + p_a_thread_now, + b_thread_mtx, + False, + p_b_thread_now, + c_thread_mtx, + False, + p_c_thread, + f_accum); + } + + // last loop + { + even_loop = !even_loop; + + FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; + FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1; + + // C = A * B + threadwise_gemm(a_thread_mtx, + True, + p_a_thread_now, + b_thread_mtx, + False, + p_b_thread_now, + c_thread_mtx, + False, + p_c_thread, + f_accum); + } + } }; diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index 5d2013ea56..92b549bb37 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -237,10 +237,15 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc, { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; - blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), - p_in_block + s * Wi + r, - p_out_thread, - f_accum); +#if 1 + blockwise_gemm.Run +#else + blockwise_gemm.Run_RegisterDoubleBuffer +#endif + (p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), + p_in_block + s * Wi + r, + p_out_thread, + f_accum); } } }