diff --git a/driver/conv.cu b/driver/conv.cu index 7852e9bed7..b09e595e07 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -376,7 +376,7 @@ int main() constexpr unsigned K = 64; constexpr unsigned S = 3; constexpr unsigned R = 3; -#elif 0 +#elif 1 constexpr unsigned N = 64; constexpr unsigned C = 256; constexpr unsigned HI = 36; @@ -427,7 +427,7 @@ int main() #endif (in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat); -#if 1 +#if 0 host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host); check_error(out_nkhw_host, out_nkhw_device); #elif 0 diff --git a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index 39a2573de2..050176416f 100644 --- a/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh +++ b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -103,19 +103,6 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, constexpr unsigned HoPerThread = 2; constexpr unsigned WoPerThread = 1; - constexpr unsigned BlockSize = 128; -#elif 0 - constexpr unsigned NPerBlock = 2; - constexpr unsigned KPerBlock = 32; - constexpr unsigned CPerBlock = 4; - constexpr unsigned HoPerBlock = 2; - constexpr unsigned WoPerBlock = 32; - - constexpr unsigned KPerThread = 4; - constexpr unsigned CPerThread = 2; - constexpr unsigned HoPerThread = 2; - constexpr unsigned WoPerThread = 2; - constexpr unsigned BlockSize = 128; #endif diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 6ee62008cc..dd7006d33c 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -75,10 +75,23 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, constexpr unsigned KPerThread = 1; constexpr unsigned CPerThread = 1; - constexpr unsigned ThreadPerClusterRow = 1; - 
constexpr unsigned ThreadPerClusterColumn = 4; + constexpr unsigned GemmThreadPerClusterRow = 1; + constexpr unsigned GemmThreadPerClusterColumn = 4; constexpr unsigned BlockSize = 32; +#elif 0 + constexpr unsigned BPerBlock = 128; + constexpr unsigned KPerBlock = 64; + constexpr unsigned CPerBlock = 2; + + constexpr unsigned BPerThread = 8; + constexpr unsigned KPerThread = 8; + constexpr unsigned CPerThread = 1; + + constexpr unsigned GemmThreadPerClusterRow = 4; + constexpr unsigned GemmThreadPerClusterColumn = 4; + + constexpr unsigned BlockSize = 128; #elif 1 constexpr unsigned BPerBlock = 128; constexpr unsigned KPerBlock = 64; @@ -88,8 +101,11 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, constexpr unsigned KPerThread = 8; constexpr unsigned CPerThread = 1; - constexpr unsigned ThreadPerClusterRow = 4; - constexpr unsigned ThreadPerClusterColumn = 4; + constexpr unsigned GemmThreadPerClusterRow = 4; + constexpr unsigned GemmThreadPerClusterColumn = 4; + + constexpr unsigned InBlockCopyThreadPerDim0 = 2; + constexpr unsigned InBlockCopyThreadPerDim1 = 64; constexpr unsigned BlockSize = 128; #endif @@ -132,8 +148,10 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, BPerThread, KPerThread, CPerThread, - ThreadPerClusterRow, - ThreadPerClusterColumn> + GemmThreadPerClusterRow, + GemmThreadPerClusterColumn, + InBlockCopyThreadPerDim0, + InBlockCopyThreadPerDim1> <<>>(in_cnhw_desc, static_cast(in_cnhw_device_buf.GetDeviceBuffer()), wei_srck_desc, diff --git a/src/include/blockwise_2d_tensor_op.cuh b/src/include/blockwise_2d_tensor_op.cuh index 43f51793d1..5ce3fad347 100644 --- a/src/include/blockwise_2d_tensor_op.cuh +++ b/src/include/blockwise_2d_tensor_op.cuh @@ -162,11 +162,188 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, } template -__device__ void blockwise_2d_tensor_copy( - SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths) +struct blockwise_2d_tensor_copy_1 { - 
constexpr auto dst_from_src_reorder = Sequence<0, 1>{}; + __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const + { + constexpr auto dst_from_src_reorder = Sequence<0, 1>{}; - blockwise_2d_tensor_copy_reorder_by_get_dst_from_src( - SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); -} + blockwise_2d_tensor_copy_reorder_by_get_dst_from_src( + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); + } +}; + +template +struct blockwise_2d_tensor_copy_2 +{ + unsigned mThreadId0; + unsigned mThreadId1; + + __device__ blockwise_2d_tensor_copy_2() + { + mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1; + mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1; + } + + __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const + { + if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1) + return; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto src_desc = SrcDesc{}; + constexpr auto dst_desc = DstDesc{}; + + constexpr unsigned L0 = SrcOpLengths{}.Get(I0); + constexpr unsigned L1 = SrcOpLengths{}.Get(I1); + + constexpr unsigned Dim0Loop = L0 / ThreadPerDim0; + constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop); + + constexpr unsigned Dim1V4Loop = L1 / (ThreadPerDim1 * 4); + constexpr unsigned Dim1V2Loop = + (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2); + constexpr unsigned Dim1V1Loop = + (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) / + ThreadPerDim1; + constexpr bool d1_has_tail = + (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop)); + + for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop) + { + unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0; + + // v4 + for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop) + { + unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1; + + for(unsigned i = 0; i < 4; ++i) + { + 
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i); + const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i); + + p_dst[dindex] = p_src[sindex]; + } + } + + // v2 + for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop) + { + unsigned did1 = + Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1; + + for(unsigned i = 0; i < 2; ++i) + { + const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i); + const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i); + + p_dst[dindex] = p_src[sindex]; + } + } + + // v1 + for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop) + { + unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 + + d1v1loop * ThreadPerDim1 + mThreadId1; + + const unsigned sindex = src_desc.Get1dIndex(did0, did1); + const unsigned dindex = dst_desc.Get1dIndex(did0, did1); + + p_dst[dindex] = p_src[sindex]; + } + + // dim-1 tail + if(d1_has_tail) + { + unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 + + Dim1V1Loop * ThreadPerDim1 + mThreadId1; + + if(did1 < L1) + { + const unsigned sindex = src_desc.Get1dIndex(did0, did1); + const unsigned dindex = dst_desc.Get1dIndex(did0, did1); + + p_dst[dindex] = p_src[sindex]; + } + } + } + + // dim-0 tail + if(d0_has_tail) + { + unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0; + + if(did0 < L0) + { + + // v4 + for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop) + { + unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1; + + for(unsigned i = 0; i < 4; ++i) + { + const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i); + const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i); + + p_dst[dindex] = p_src[sindex]; + } + } + + // v2 + for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop) + { + unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + + 2 * mThreadId1; + + for(unsigned i = 0; i < 2; ++i) + { + const unsigned sindex = 
src_desc.Get1dIndex(did0, did1 + i);
+                        const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
+
+                        p_dst[dindex] = p_src[sindex];
+                    }
+                }
+
+                // v1
+                for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
+                {
+                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
+                                    Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
+                                    mThreadId1;
+
+                    const unsigned sindex = src_desc.Get1dIndex(did0, did1);
+                    const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+
+                    p_dst[dindex] = p_src[sindex];
+                }
+
+                // tail
+                if(d1_has_tail)
+                {
+                    unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
+                                    Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
+                                    mThreadId1;
+
+                    if(did1 < L1)
+                    {
+                        const unsigned sindex = src_desc.Get1dIndex(did0, did1);
+                        const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
+
+                        p_dst[dindex] = p_src[sindex];
+                    }
+                }
+            }
+        }
+    }
+};
diff --git a/src/include/common.cuh b/src/include/common.cuh
index ba910853d7..9885cabf84 100644
--- a/src/include/common.cuh
+++ b/src/include/common.cuh
@@ -1,5 +1,7 @@
 #pragma once
 
+#define WARPSIZE 32
+
 template 
 struct is_same
 {
diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
index 820626ce5b..04c9000a9e 100644
--- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
+++ b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
@@ -153,6 +153,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
     for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
         c_block_data_begin += CPerBlock, __syncthreads())
     {
+#if 1
         // input: global mem to LDS,
         // convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
         blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(
@@ -165,7 +166,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
             p_in_block,
             in_nchw_block_desc.GetLengths(),
             reorder_chwn_from_nchw);
+#endif
+#if 1
         // weight: global mem to LDS,
// format is [S,R,C,K], no conversion needed blockwise_4d_tensor_copy( @@ -175,6 +178,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, wei_srck_block_desc, p_wei_block, wei_srck_block_desc.GetLengths()); +#endif __syncthreads(); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 2a609f046a..0fc4381a9b 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -20,8 +20,10 @@ template + unsigned GemmThreadPerClusterRow, + unsigned GemmThreadPerClusterColumn, + unsigned InBlockCopyThreadPerDim0, + unsigned InBlockCopyThreadPerDim1> __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, Float* const __restrict__ p_in_global, @@ -104,6 +106,26 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, } #endif +#if 1 + // blockwise 2d copy + const auto blockwise_2d_copy = + blockwise_2d_tensor_copy_1{}; +#elif 0 + // blockwise 2d copy + const auto blockwise_2d_copy = + blockwise_2d_tensor_copy_2{}; +#endif + // a series of blockwise GEMM // c_mtx += transpose(a_mtx) * b_mtx // a_mtx and b_mtx saved in LDS, c_mtx saved in register @@ -130,8 +152,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, false, false, CPerThread, - ThreadPerClusterRow, - ThreadPerClusterColumn, + GemmThreadPerClusterRow, + GemmThreadPerClusterColumn, true>{}; // LDS @@ -152,12 +174,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, { // input: global mem to LDS, // formmat is [CPerBlock,BPerBlock + BGhostRead] - blockwise_2d_tensor_copy( - in_cb_global_desc, + blockwise_2d_copy.run( p_in_global + in_cb_global_desc.Get1dIndex(c_block_data_begin, b_block_data_begin), - in_cb_block_desc, - p_in_block, - in_cb_block_desc.GetLengths()); + p_in_block); // weight: global mem to LDS, // format is 
[S,R,CPerBlock,KPerBlock] @@ -245,22 +264,6 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] = p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; #endif - -#if 0 - if(get_block_1d_id() == 0) - { - printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - k, - b, - k_data, - n_data, - h_data, - w_data, - p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]); - } -#endif } } }