From 89ee259752fe94c74ad894496cb8cf71276ea43a Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 16 Jan 2019 02:44:10 -0600 Subject: [PATCH] adding implicit gemm --- driver/conv.cu | 20 +++-------- driver/device_implicit_gemm_convolution.cuh | 19 +++++++--- src/include/blockwise_tensor_op.cuh | 2 +- src/include/gemm.cuh | 3 -- ...se_implicit_gemm_convolution_nchw_kcsr.cuh | 35 +++++++++++-------- ...se_implicit_gemm_convolution_nchw_srck.cuh | 4 +-- 6 files changed, 44 insertions(+), 39 deletions(-) diff --git a/driver/conv.cu b/driver/conv.cu index 64a4ceb714..7b53422200 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -343,7 +343,7 @@ int main() constexpr unsigned K = 1; constexpr unsigned S = 3; constexpr unsigned R = 3; -#elif 1 +#elif 0 constexpr unsigned N = 1; constexpr unsigned C = 1; constexpr unsigned HI = 34; @@ -396,19 +396,14 @@ int main() #if 0 in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread); -#elif 0 +#elif 1 in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); -#elif 0 - in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei_kcsr.GenerateTensorValue(GeneratorTensor_3{}, num_thread); -#elif 1 - in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread); - wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread); #endif #if 1 auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence{}); + ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: "); Tensor wei_srck(make_TensorDescriptor(wei_srck_desc)); auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) { @@ -418,12 +413,7 @@ int main() make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)(num_thread); #endif -#if 0 - wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - out_nkhw_device.GenerateTensorValue(GeneratorTensor_1{}, num_thread); -#endif - - for(int i = 0; i < 1; ++i) + for(int i = 0; i < 40; ++i) { #if 0 device_direct_convolution_1(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); @@ -450,7 +440,7 @@ int main() check_error(out_nkhw_host, out_nkhw_device); #endif -#if 1 +#if 0 LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; diff --git a/driver/device_implicit_gemm_convolution.cuh b/driver/device_implicit_gemm_convolution.cuh index 1f776a2974..515717df9e 100644 --- a/driver/device_implicit_gemm_convolution.cuh +++ b/driver/device_implicit_gemm_convolution.cuh @@ -1,5 +1,5 @@ #pragma once -//#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh" +#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh" #include "gridwise_implicit_gemm_convolution_nchw_srck.cuh" template @@ -26,14 +26,13 @@ void device_implicit_gemm_convolution( constexpr auto wei_desc = WeiDesc{}; constexpr auto out_desc = OutDesc{}; -#if 1 +#if 0 constexpr unsigned NPerBlock = 1; constexpr unsigned KPerBlock = 1; constexpr unsigned CPerBlock = 1; constexpr unsigned HoPerBlock = 2; constexpr unsigned WoPerBlock = 32; - constexpr unsigned NPerThread = 1; constexpr unsigned KPerThread = 1; constexpr unsigned CPerThread = 1; constexpr unsigned HoPerThread = 2; @@ -47,13 +46,25 @@ void device_implicit_gemm_convolution( constexpr unsigned HoPerBlock = 2; constexpr unsigned WoPerBlock = 32; - constexpr unsigned NPerThread = 2; constexpr unsigned KPerThread = 4; constexpr unsigned CPerThread = 2; constexpr unsigned HoPerThread = 2; constexpr unsigned WoPerThread = 2; constexpr unsigned BlockSize = 128; +#elif 1 + constexpr unsigned NPerBlock = 2; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 32; + + constexpr unsigned KPerThread = 4; + constexpr unsigned CPerThread = 1; + constexpr unsigned HoPerThread = 2; + constexpr unsigned WoPerThread = 2; + + constexpr unsigned BlockSize = 256; #endif constexpr unsigned GridSize = diff --git a/src/include/blockwise_tensor_op.cuh b/src/include/blockwise_tensor_op.cuh index ed5b080e0a..8d2426ba4a 100644 --- a/src/include/blockwise_tensor_op.cuh +++ b/src/include/blockwise_tensor_op.cuh @@ -135,7 +135,7 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); -#if 1 +#if 0 printf("did %u %u %u %u, did_IR %u %u %u %u, index %u %u\n", did[0], did[1], diff --git a/src/include/gemm.cuh b/src/include/gemm.cuh index 0a3789580f..507f08f51b 100644 --- a/src/include/gemm.cuh +++ b/src/include/gemm.cuh @@ -145,9 +145,6 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c static_assert(BlockSize == BThreadWork * MThreadWork * NThreadWork, "wrong! wrong BlockSize"); - // printf("%u %u, %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), MThreadWork, - // NThreadWork); - if(DistributeThreadAlongColumnFirst) { // num of operations can be reduced diff --git a/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh index 08f98fce0b..662e839b52 100644 --- a/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh @@ -82,25 +82,24 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, make_ConstantTensorDescriptor(Sequence{}); // tensor view of reordered blockwise input and weight in LDS - constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{}; - constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor( - in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw)); - constexpr auto reorder_srck_from_kcsr = Sequence<2, 3, 1, 0>{}; constexpr auto wei_srck_block_desc = make_ConstantTensorDescriptor( wei_kcsr_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_srck_from_kcsr)); + constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{}; + constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor( + in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw)); + // tensor view of threadwise output in register constexpr auto out_hkwn_thread_desc = make_ConstantTensorDescriptor(Sequence{}); -#if 0 +#if 1 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); - print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc"); print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); @@ -124,8 +123,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( Number{}, Number{}); // constexpr doesn't compile - auto f_accum = [](auto& c, auto& ab) { c += ab; }; - const auto blockwise_batch_gemm = blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{}; + true>{}; // LDS constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); @@ -176,7 +173,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, #if 1 // weight: global mem to LDS, - // convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K] blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( wei_kcsr_global_desc, p_wei_global + @@ -189,24 +185,29 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, __syncthreads(); +#if 1 // a series of batched GEMM for(unsigned s = 0; s < S; ++s) { for(unsigned r = 0; r < R; ++r) { + auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; + blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0), - p_out_thread); + p_out_thread, + f_accum); } } +#endif } const auto matrix_c_index = blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; - const unsigned k_thread_data_begin = matrix_c_index.col_begin; - const unsigned wo_thread_data_begin = matrix_c_index.row_begin / NPerThread; + const unsigned k_thread_data_begin = matrix_c_index.row_begin; + const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread; // output: register to global mem, // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] @@ -222,4 +223,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, wo_block_data_begin + wo_thread_data_begin), out_hkwn_thread_desc.GetLengths(), reorder_nkhw_from_hkwn); + + // printf("%f %f %f %f\n", p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]); + // printf("%u %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), + // matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin); printf("%u + // %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), ho_thread_data_begin, + // k_thread_data_begin, wo_thread_data_begin); } diff --git a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh index 1a546d6c97..1874c251ac 100644 --- a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh @@ -90,7 +90,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, constexpr auto out_hkwn_thread_desc = make_ConstantTensorDescriptor(Sequence{}); -#if 1 +#if 0 if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) { print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); @@ -189,7 +189,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), p_out_thread, f_accum); }