diff --git a/driver/conv.cu b/driver/conv.cu index 5d039a4f86..b5bf0a2473 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -614,7 +614,7 @@ int main() nrepeat); #endif -#if 1 +#if 0 if(S == 3 && R == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh index e801e15b2b..7f3ddc7299 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh @@ -128,7 +128,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, constexpr unsigned BlockSize = 64; #elif 1 - // 1x1, 28x28, 128 threads + // 1x1, 28x28, 128 threads, no lds-double-buffer + // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128 constexpr unsigned BPerBlock = 64; constexpr unsigned KPerBlock = 128; constexpr unsigned CPerBlock = 8; @@ -215,37 +216,37 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc, cudaEventCreate(&start); cudaEventRecord(start, 0); -#if 1 +#if 0 gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw #else gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer #endif - + <<>>(in_cnhw_desc, static_cast(in_cnhw_device_buf.GetDeviceBuffer()), wei_csrk_desc, diff --git a/src/include/blockwise_2d_tensor_op.cuh b/src/include/blockwise_2d_tensor_op.cuh index 1231a3863b..ff4476d0b7 100644 --- a/src/include/blockwise_2d_tensor_op.cuh +++ b/src/include/blockwise_2d_tensor_op.cuh @@ -512,4 +512,196 @@ struct Blockwise2dTensorCopy3 } } } + +#if 1 + __device__ constexpr unsigned GetRegisterClipboardSize() const + { + static_assert(is_same::value, "wrong! only support float!\n"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr unsigned L0 = CopyLengths{}.Get(I0); + constexpr unsigned L1 = CopyLengths{}.Get(I1); + + constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; + constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; + + return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0; + } + + __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src, + Float* p_clipboard) const + { + static_assert(is_same::value, "wrong! only support float!\n"); + + using Float2 = float2; + using Float4 = float4; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr unsigned L0 = CopyLengths{}.Get(I0); + constexpr unsigned L1 = CopyLengths{}.Get(I1); + + constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; + constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; + + constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1; + + if(BlockSize > num_active_thread) + { + if(get_thread_local_1d_id() >= num_active_thread) + { + return; + } + } + + constexpr unsigned nloop_d0 = L0 / thread_per_d0; + + constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0; + constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0; + + for(unsigned iloop = 0; iloop < nloop_d0; ++iloop) + { + if(DataPerRead == 1) + { + p_clipboard[iloop] = p_src[mSrcMyThreadOffset + iloop * src_loop_stride]; + } + else if(DataPerRead == 2) + { + *(reinterpret_cast(p_clipboard + iloop * 2)) = + *(reinterpret_cast(p_src + mSrcMyThreadOffset + + iloop * src_loop_stride)); + } + else if(DataPerRead == 4) + { + *(reinterpret_cast(p_clipboard + iloop * 4)) = + *(reinterpret_cast(p_src + mSrcMyThreadOffset + + iloop * src_loop_stride)); + } + else + { + assert(false); + } + } + + constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0); + + if(has_tail_d0) + { + constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0; + + if(get_thread_local_1d_id() < tail_d0 * thread_per_d1) + { + if(DataPerRead == 1) + { + p_clipboard[nloop_d0] = p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride]; + } + else if(DataPerRead == 2) + { + *(reinterpret_cast(p_clipboard + nloop_d0 * 2)) = + *(reinterpret_cast(p_src + mSrcMyThreadOffset + + nloop_d0 * src_loop_stride)); + } + else if(DataPerRead == 4) + { + *(reinterpret_cast(p_clipboard + nloop_d0 * 4)) = + *(reinterpret_cast(p_src + mSrcMyThreadOffset + + nloop_d0 * src_loop_stride)); + } + else + { + assert(false); + } + } + } + } + + __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard, + Float* __restrict__ p_dst) const + { + static_assert(is_same::value, "wrong! only support float!\n"); + + using Float2 = float2; + using Float4 = float4; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr unsigned L0 = CopyLengths{}.Get(I0); + constexpr unsigned L1 = CopyLengths{}.Get(I1); + + constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; + constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1; + + constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1; + + if(BlockSize > num_active_thread) + { + if(get_thread_local_1d_id() >= num_active_thread) + { + return; + } + } + + constexpr unsigned nloop_d0 = L0 / thread_per_d0; + + constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0; + constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0; + + for(unsigned iloop = 0; iloop < nloop_d0; ++iloop) + { + if(DataPerRead == 1) + { + p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] = p_clipboard[iloop]; + } + else if(DataPerRead == 2) + { + *(reinterpret_cast(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) = + *(reinterpret_cast(p_clipboard + iloop * 2)); + } + else if(DataPerRead == 4) + { + *(reinterpret_cast(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) = + *(reinterpret_cast(p_clipboard + iloop * 4)); + } + else + { + assert(false); + } + } + + constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0); + + if(has_tail_d0) + { + constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0; + + if(get_thread_local_1d_id() < tail_d0 * thread_per_d1) + { + if(DataPerRead == 1) + { + p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] = p_clipboard[nloop_d0]; + } + else if(DataPerRead == 2) + { + *(reinterpret_cast(p_dst + mDstMyThreadOffset + + nloop_d0 * dst_loop_stride)) = + *(reinterpret_cast(p_clipboard + nloop_d0 * 2)); + } + else if(DataPerRead == 4) + { + *(reinterpret_cast(p_dst + mDstMyThreadOffset + + nloop_d0 * dst_loop_stride)) = + *(reinterpret_cast(p_clipboard + nloop_d0 * 4)); + } + else + { + assert(false); + } + } + } + } +#endif }; diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh index 7e10a62faf..16165f6195 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh @@ -262,8 +262,26 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b __syncthreads(); // load next data +#if 0 blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next); +#elif 0 + blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next); + + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); +#elif 1 + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); +#endif // compute on current data // a series of GEMM @@ -283,6 +301,13 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b f_accum); } } + +#if 0 + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); +#elif 1 + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next); +#endif } // last computation