From b5b4fd28eda1102555ec02be080ec0eaf5c8762d Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 21 Jan 2019 15:33:34 -0600 Subject: [PATCH] refactor --- driver/conv.cu | 6 +- ...icit_gemm_convolution_1_nchw_srck_nkhw.cuh | 45 ++-- ...icit_gemm_convolution_2_cnhw_srck_knhw.cuh | 26 +-- src/include/gemm.cuh | 210 +++++++++++++++++- ...icit_gemm_convolution_1_nchw_srck_nkhw.cuh | 10 +- ...icit_gemm_convolution_2_cnhw_srck_knhw.cuh | 49 ++-- 6 files changed, 271 insertions(+), 75 deletions(-) diff --git a/driver/conv.cu b/driver/conv.cu index 42a9e950e5..feb665d96c 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -354,10 +354,10 @@ int main() { #if 0 constexpr unsigned N = 1; - constexpr unsigned C = 2; + constexpr unsigned C = 1; constexpr unsigned HI = 34; constexpr unsigned WI = 34; - constexpr unsigned K = 2; + constexpr unsigned K = 4; constexpr unsigned S = 3; constexpr unsigned R = 3; #elif 1 @@ -418,7 +418,7 @@ int main() device_direct_convolution_2 #elif 0 device_implicit_gemm_convolution_1_nchw_kcsr -#elif 1 +#elif 0 device_implicit_gemm_convolution_1_nchw_srck_nkhw #elif 1 device_implicit_gemm_convolution_2_cnhw_srck_knhw diff --git a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index 293b46b5a1..39a2573de2 100644 --- a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh +++ b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -4,12 +4,12 @@ template void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcsr, - OutDesc, - Tensor& out_nkhw, - unsigned nrepeat) + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcsr, + OutDesc, + Tensor& out_nkhw, + unsigned nrepeat) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -104,7 +104,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, constexpr unsigned WoPerThread = 1; constexpr unsigned BlockSize = 128; -#elif 1 +#elif 0 constexpr unsigned NPerBlock = 2; constexpr unsigned KPerBlock = 32; constexpr unsigned CPerBlock = 4; @@ -137,20 +137,20 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, cudaEventRecord(start, 0); gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw + BlockSize, + T, + decltype(in_nchw_desc), + decltype(wei_srck_desc), + decltype(out_nkhw_desc), + NPerBlock, + KPerBlock, + CPerBlock, + HoPerBlock, + WoPerBlock, + KPerThread, + CPerThread, + HoPerThread, + WoPerThread> <<>>(in_nchw_desc, static_cast(in_nchw_device_buf.GetDeviceBuffer()), wei_srck_desc, @@ -165,10 +165,9 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, cudaEventElapsedTime(&elapsedTime, start, stop); printf("Elapsed time : %f ms\n", elapsedTime); - usleep(10); + usleep(10000); } - checkCudaErrors(cudaGetLastError()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); } diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 21b5c3b43e..9585113fb9 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -1,5 +1,6 @@ #pragma once #include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" +#include template void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, @@ -67,35 +68,29 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, #if 0 constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 1; + constexpr unsigned KPerBlock = 4; constexpr unsigned CPerBlock = 1; constexpr unsigned BPerThread = 4; constexpr unsigned KPerThread = 1; constexpr unsigned CPerThread = 1; - constexpr unsigned BlockSize = 32; -#elif 0 - constexpr unsigned BPerBlock = 128; - constexpr unsigned KPerBlock = 2; - constexpr unsigned CPerBlock = 2; + constexpr unsigned ThreadPerClusterRow = 4; + constexpr unsigned ThreadPerClusterColumn = 16; - constexpr unsigned BPerThread = 4; - constexpr unsigned KPerThread = 2; - constexpr unsigned CPerThread = 1; - - constexpr unsigned BlockSize = 32; + constexpr unsigned BlockSize = 128; #elif 1 constexpr unsigned BPerBlock = 128; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 2; - constexpr unsigned BPerBatch = 32; - constexpr unsigned BPerThread = 4; constexpr unsigned KPerThread = 16; constexpr unsigned CPerThread = 1; + constexpr unsigned ThreadPerClusterRow = 4; + constexpr unsigned ThreadPerClusterColumn = 16; + constexpr unsigned BlockSize = 128; #endif @@ -137,7 +132,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, BPerThread, KPerThread, CPerThread, - BPerBatch> + ThreadPerClusterRow, + ThreadPerClusterColumn> <<>>(in_cnhw_desc, static_cast(in_cnhw_device_buf.GetDeviceBuffer()), wei_srck_desc, @@ -151,6 +147,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, cudaEventElapsedTime(&elapsedTime, start, stop); printf("Elapsed time : %f ms\n", elapsedTime); + + usleep(10000); } checkCudaErrors(cudaGetLastError()); diff --git a/src/include/gemm.cuh b/src/include/gemm.cuh index 62b2625bff..760cc1ad4d 100644 --- a/src/include/gemm.cuh +++ b/src/include/gemm.cuh @@ -156,11 +156,11 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0"); static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0"); - constexpr unsigned BThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread; - constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread; - constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; + constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread; + constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread; + constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; - static_assert(BlockSize == BThreadWork * MThreadWork * NThreadWork, + static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork, "wrong! wrong BlockSize"); if(DistributeThreadAlongColumnFirst) @@ -289,3 +289,205 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c } } }; + +template +struct blockwise_gemm_block_a_block_b_thread_c +{ + unsigned mMyThreadOffsetA = 0; + unsigned mMyThreadOffsetB = 0; + + struct MatrixIndex + { + unsigned row_begin; + unsigned col_begin; + }; + + __device__ blockwise_gemm_block_a_block_b_thread_c() + { + const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile + const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile + + const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + + mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0) + : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin); + + mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin) + : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: "); + print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: "); + print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: "); + + printf("%u %u, %u %u %u, %u %u\n", + get_block_1d_id(), + get_thread_local_1d_id(), + c_thread_mtx_index.batch_begin, + c_thread_mtx_index.row_begin, + c_thread_mtx_index.col_begin, + mMyThreadOffsetA, + mMyThreadOffsetB); + } +#endif + } + + __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const + { + + if(TransA && (!TransB) && (!TransC)) + { + const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile + const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile + + static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), + "wrong! k dimension not consistent!"); + + constexpr unsigned MPerBlock = a_block_mtx.NCol(); + constexpr unsigned NPerBlock = b_block_mtx.NCol(); + + const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + + // divide thread work + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0, + "MPerBlock % (MPerThread * MThreadPerCluster) != 0"); + + static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0, + "NPerBlock % (NPerThread * NThreadPerCluster) != 0"); + + constexpr unsigned MClusterWork = + (MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster); + + constexpr unsigned NClusterWork = + (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster); + + static_assert(BlockSize == (MClusterWork * MThreadPerCluster) * + (NClusterWork * NThreadPerCluster), + "wrong! wrong BlockSize"); + + if(DistributeThreadAlongColumnFirst) + { + const unsigned cluster_work_block_id = + thread_id / (MThreadPerCluster * NThreadPerCluster); + + const unsigned thread_work_cluster_id = + thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster); + + const unsigned m_cluster_work_block_id = cluster_work_block_id / NThreadPerCluster; + const unsigned n_cluster_work_block_id = + cluster_work_block_id - m_cluster_work_block_id * NThreadPerCluster; + + const unsigned m_thread_work_cluster_id = + thread_work_cluster_id / NThreadPerCluster; + const unsigned n_thread_work_cluster_id = + thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster; + +#if 0 + if(get_block_1d_id() == 0) + { + printf("%u %u, \t" + //"MClusterWork %u MThreadPerCluster %u NClusterWork %u NThreadPerCluster %u \t" + "m_cluster_work_block_id %u n_cluster_work_block_id %u \t" + "m_thread_work_cluster_id %u n_thread_work_cluster_id %u \t" + "\n", + get_block_1d_id(), get_thread_local_1d_id(), + //MClusterWork, MThreadPerCluster, NClusterWork, NThreadPerCluster, + m_cluster_work_block_id, n_cluster_work_block_id, + m_thread_work_cluster_id, n_thread_work_cluster_id); + } +#endif + + return MatrixIndex{m_cluster_work_block_id * (MThreadPerCluster * MPerThread) + + m_thread_work_cluster_id * MPerThread, + n_cluster_work_block_id * (NThreadPerCluster * NPerThread) + + n_thread_work_cluster_id * NPerThread}; + } + else + { + // not implemented + assert(false); + } + } + else + { + // not implemented + assert(false); + } + } + + template + __device__ void run(FloatA* const p_a_block, + FloatB* const p_b_block, + FloatC* p_c_thread, + Accumulator f_accum) const + { + if(TransA && (!TransB) && (!TransC)) + { + constexpr auto True = Constant{}; + constexpr auto False = Constant{}; + + const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile + const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile + const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + + constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + // a is transposed, b is not + const auto a_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto b_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + // loop over k + for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) + { + threadwise_matrix_copy(a_block_mtx, + p_a_block + mMyThreadOffsetA + + k_begin * a_block_mtx.RowStride(), + a_thread_mtx, + p_a_thread, + a_thread_mtx.GetLengths()); + + threadwise_matrix_copy(b_block_mtx, + p_b_block + mMyThreadOffsetB + + k_begin * b_block_mtx.RowStride(), + b_thread_mtx, + p_b_thread, + b_thread_mtx.GetLengths()); + + threadwise_gemm(a_thread_mtx, + True, + p_a_thread, + b_thread_mtx, + False, + p_b_thread, + c_thread_mtx, + False, + p_c_thread, + f_accum); + } + } + } +}; diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index dc98754390..820626ce5b 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -23,11 +23,11 @@ template __global__ void gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, - Float* const __restrict__ p_in_global, - WeiGlobalDesc, - Float* const __restrict__ p_wei_global, - OutGlobalDesc, - Float* __restrict__ p_out_global) + Float* const __restrict__ p_in_global, + WeiGlobalDesc, + Float* const __restrict__ p_wei_global, + OutGlobalDesc, + Float* __restrict__ p_out_global) { // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 70f401e624..5ab64ca96d 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -20,7 +20,8 @@ template + unsigned ThreadPerClusterRow, + unsigned ThreadPerClusterColumn> __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, Float* const __restrict__ p_in_global, @@ -112,31 +113,26 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( Number{}, Number{}); // constexpr doesn't compile - static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n"); - const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( Number{}, - Number{}, + Number{}, Number{}); // constexpr doesn't compile const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( Number{}, Number{}); // constexpr doesn't compile - const auto blockwise_batched_gemm = - blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{}; + const auto blockwise_gemm = + blockwise_gemm_block_a_block_b_thread_c{}; // LDS constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace(); @@ -175,6 +171,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, __syncthreads(); +#if 1 // a series of GEMM for(unsigned s = 0; s < S; ++s) { @@ -182,31 +179,31 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, { auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), + blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), p_in_block + s * Wi + r, p_out_thread, f_accum); } } +#endif } // output: register to global mem, const auto matrix_c_index = - blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned b_thread_data_begin = matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin; + const unsigned b_thread_data_begin = matrix_c_index.col_begin; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; #if 0 - //if(get_block_1d_id() == 10) + if(get_block_1d_id() == 0) { - printf("%u %u, batch_begin %u row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", + printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", get_block_1d_id(), get_thread_local_1d_id(), - matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin, k_data_begin, @@ -228,7 +225,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, unsigned w_data = itmp - h_data * Wi; #if 0 - if(get_block_1d_id() == 10) + if(get_block_1d_id() == 0) { printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", get_block_1d_id(),