From 5872b710df80d0ebb85b8ea7d289a1f705ee4afb Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 16 Jan 2019 15:45:02 -0600 Subject: [PATCH] refactor --- driver/conv.cu | 21 +-- ...e_implicit_gemm_convolution_nchw_kcsr.cuh} | 49 +++--- ...ce_implicit_gemm_convolution_nchw_srck.cuh | 140 ++++++++++++++++++ src/include/blockwise_tensor_op.cuh | 14 -- ...se_implicit_gemm_convolution_nchw_kcsr.cuh | 9 +- ...se_implicit_gemm_convolution_nchw_srck.cuh | 9 +- 6 files changed, 171 insertions(+), 71 deletions(-) rename driver/{device_implicit_gemm_convolution.cuh => device_implicit_gemm_convolution_nchw_kcsr.cuh} (68%) create mode 100644 driver/device_implicit_gemm_convolution_nchw_srck.cuh diff --git a/driver/conv.cu b/driver/conv.cu index 0eb93bf1c3..db1259140d 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -8,7 +8,8 @@ #include "conv_common.cuh" #include "device_direct_convolution_1.cuh" #include "device_direct_convolution_2.cuh" -#include "device_implicit_gemm_convolution.cuh" +#include "device_implicit_gemm_convolution_nchw_kcsr.cuh" +#include "device_implicit_gemm_convolution_nchw_srck.cuh" //#include "device_winograd_convolution.cuh" struct GeneratorTensor_1 @@ -393,18 +394,6 @@ int main() wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); #endif -#if 1 - auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: "); - Tensor wei_srck(make_TensorDescriptor(wei_srck_desc)); - - auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) { - wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r); - }; - - make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)(num_thread); -#endif - for(int i = 0; i < 40; ++i) { #if 0 @@ -413,11 +402,11 @@ int main() device_direct_convolution_2( in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); #elif 0 - device_implicit_gemm_convolution( + device_implicit_gemm_convolution_nchw_kcsr( in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); #elif 1 - device_implicit_gemm_convolution( - in_nchw_desc, in_nchw, wei_srck_desc, wei_srck, out_nkhw_desc, out_nkhw_device); + device_implicit_gemm_convolution_nchw_srck( + in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); #elif 0 device_winograd_convolution( in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); diff --git a/driver/device_implicit_gemm_convolution.cuh b/driver/device_implicit_gemm_convolution_nchw_kcsr.cuh similarity index 68% rename from driver/device_implicit_gemm_convolution.cuh rename to driver/device_implicit_gemm_convolution_nchw_kcsr.cuh index 1fe3860d6a..5e5a9adebf 100644 --- a/driver/device_implicit_gemm_convolution.cuh +++ b/driver/device_implicit_gemm_convolution_nchw_kcsr.cuh @@ -1,9 +1,8 @@ #pragma once #include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh" -#include "gridwise_implicit_gemm_convolution_nchw_srck.cuh" template -void device_implicit_gemm_convolution( +void device_implicit_gemm_convolution_nchw_kcsr( InDesc, const Tensor& in, WeiDesc, const Tensor& wei, OutDesc, Tensor& out) { std::size_t data_sz = sizeof(T); @@ -82,31 +81,27 @@ void device_implicit_gemm_convolution( cudaEventCreate(&start); cudaEventRecord(start, 0); -#if 0 - gridwise_implicit_gemm_convolution_nchw_kcsr -#elif 1 - gridwise_implicit_gemm_convolution_nchw_srck -#endif - <<>>(InDesc{}, - static_cast(in_device_buf.GetDeviceBuffer()), - WeiDesc{}, - static_cast(wei_device_buf.GetDeviceBuffer()), - OutDesc{}, - static_cast(out_device_buf.GetDeviceBuffer())); + gridwise_implicit_gemm_convolution_nchw_kcsr + <<>>(InDesc{}, + static_cast(in_device_buf.GetDeviceBuffer()), + WeiDesc{}, + static_cast(wei_device_buf.GetDeviceBuffer()), + OutDesc{}, + static_cast(out_device_buf.GetDeviceBuffer())); cudaEventCreate(&stop); cudaEventRecord(stop, 0); diff --git a/driver/device_implicit_gemm_convolution_nchw_srck.cuh b/driver/device_implicit_gemm_convolution_nchw_srck.cuh new file mode 100644 index 0000000000..adb41a8c23 --- /dev/null +++ b/driver/device_implicit_gemm_convolution_nchw_srck.cuh @@ -0,0 +1,140 @@ +#pragma once +#include "gridwise_implicit_gemm_convolution_nchw_srck.cuh" + +template +void device_implicit_gemm_convolution_nchw_srck(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcsr, + OutDesc, + Tensor& out_nkhw) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_desc = InDesc{}; + constexpr auto wei_kcsr_desc = WeiDesc{}; + constexpr auto out_nkhw_desc = OutDesc{}; + + constexpr unsigned N = out_nkhw_desc.GetLength(I0); + constexpr unsigned Ho = out_nkhw_desc.GetLength(I2); + constexpr unsigned Wo = out_nkhw_desc.GetLength(I3); + + constexpr unsigned K = wei_kcsr_desc.GetLength(I0); + constexpr unsigned C = wei_kcsr_desc.GetLength(I1); + constexpr unsigned S = wei_kcsr_desc.GetLength(I2); + constexpr unsigned R = wei_kcsr_desc.GetLength(I3); + + auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: "); + + Tensor wei_srck(make_TensorDescriptor(wei_srck_desc)); + + auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) { + wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r); + }; + + make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)( + std::thread::hardware_concurrency()); + + std::size_t data_sz = sizeof(T); + DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); + DeviceMem wei_srck_device_buf(data_sz * wei_srck.mDesc.GetElementSpace()); + DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); + + int num_thread = std::thread::hardware_concurrency(); + + in_nchw_device_buf.ToDevice(in_nchw.mData.data()); + wei_srck_device_buf.ToDevice(wei_srck.mData.data()); + out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); + +#if 0 + constexpr unsigned NPerBlock = 1; + constexpr unsigned KPerBlock = 1; + constexpr unsigned CPerBlock = 1; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 32; + + constexpr unsigned KPerThread = 1; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 2; + constexpr unsigned WoPerThread = 2; + + constexpr unsigned BlockSize = 16; +#elif 1 + constexpr unsigned NPerBlock = 2; + constexpr unsigned KPerBlock = 32; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 32; + + constexpr unsigned KPerThread = 4; + constexpr unsigned CPerThread = 2; + constexpr unsigned HoPerThread = 2; + constexpr unsigned WoPerThread = 2; + + constexpr unsigned BlockSize = 128; +#elif 0 + constexpr unsigned NPerBlock = 2; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 32; + + constexpr unsigned KPerThread = 4; + constexpr unsigned CPerThread = 2; + constexpr unsigned HoPerThread = 2; + constexpr unsigned WoPerThread = 2; + + constexpr unsigned BlockSize = 256; +#endif + + constexpr unsigned GridSize = + ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * + ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); + + dim3 block_dim(BlockSize); + dim3 grid_dim(GridSize); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + cudaEvent_t start, stop; + float elapsedTime; + + cudaEventCreate(&start); + cudaEventRecord(start, 0); + + gridwise_implicit_gemm_convolution_nchw_srck + <<>>(in_nchw_desc, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + wei_srck_desc, + static_cast(wei_srck_device_buf.GetDeviceBuffer()), + out_nkhw_desc, + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + + cudaEventCreate(&stop); + cudaEventRecord(stop, 0); + cudaEventSynchronize(stop); + + cudaEventElapsedTime(&elapsedTime, start, stop); + printf("Elapsed time : %f ms\n", elapsedTime); + + checkCudaErrors(cudaGetLastError()); + out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); +} diff --git a/src/include/blockwise_tensor_op.cuh b/src/include/blockwise_tensor_op.cuh index 8d2426ba4a..13e2093333 100644 --- a/src/include/blockwise_tensor_op.cuh +++ b/src/include/blockwise_tensor_op.cuh @@ -135,20 +135,6 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); -#if 0 - printf("did %u %u %u %u, did_IR %u %u %u %u, index %u %u\n", - did[0], - did[1], - did[2], - did[3], - did[IR0], - did[IR1], - did[IR2], - did[IR3], - aindex, - bindex); -#endif - f(p_src[aindex], p_dst[bindex]); } diff --git a/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh index 67becb5405..6b83a3d51f 100644 --- a/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh @@ -157,7 +157,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, c_block_data_begin += CPerBlock, __syncthreads()) { // input: global mem to LDS, - // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N] + // convert [N,C,Hi,Wi] to [C,Hi,Wi,N] blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( in_nchw_global_desc, p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, @@ -170,6 +170,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, reorder_chwn_from_nchw); // weight: global mem to LDS, + // convert [K,C,S,R] to [S,R,C,K] blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( wei_kcsr_global_desc, p_wei_global + @@ -217,10 +218,4 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, wo_block_data_begin + wo_thread_data_begin), out_hkwn_thread_desc.GetLengths(), reorder_nkhw_from_hkwn); - - // printf("%f %f %f %f\n", p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - // printf("%u %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), - // matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin); printf("%u - // %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), ho_thread_data_begin, - // k_thread_data_begin, wo_thread_data_begin); } diff --git a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh index c495f5ecb5..c08d026d33 100644 --- a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh @@ -153,7 +153,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, c_block_data_begin += CPerBlock, __syncthreads()) { // input: global mem to LDS, - // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N] + // convert [N,C,Hi,Wi] to [C,Hi,Wi,N] blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( in_nchw_global_desc, p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, @@ -166,6 +166,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, reorder_chwn_from_nchw); // weight: global mem to LDS, + // format is [S,R,C,K], no conversion needed blockwise_4d_tensor_copy( wei_srck_global_desc, p_wei_global + @@ -212,10 +213,4 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, wo_block_data_begin + wo_thread_data_begin), out_hkwn_thread_desc.GetLengths(), reorder_nkhw_from_hkwn); - - // printf("%f %f %f %f\n", p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); - // printf("%u %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), - // matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin); printf("%u - // %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), ho_thread_data_begin, - // k_thread_data_begin, wo_thread_data_begin); }