diff --git a/driver/conv.cu b/driver/conv.cu index 88aae323f3..42a9e950e5 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -9,7 +9,7 @@ #include "device_direct_convolution_1.cuh" #include "device_direct_convolution_2.cuh" #include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh" -#include "device_implicit_gemm_convolution_1_nchw_srck.cuh" +#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" #include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" //#include "device_winograd_convolution.cuh" @@ -418,8 +418,8 @@ int main() device_direct_convolution_2 #elif 0 device_implicit_gemm_convolution_1_nchw_kcsr -#elif 0 - device_implicit_gemm_convolution_1_nchw_srck +#elif 1 + device_implicit_gemm_convolution_1_nchw_srck_nkhw #elif 1 device_implicit_gemm_convolution_2_cnhw_srck_knhw #elif 0 diff --git a/driver/device_implicit_gemm_convolution_1_nchw_srck.cuh b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh similarity index 62% rename from driver/device_implicit_gemm_convolution_1_nchw_srck.cuh rename to driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index b7c527fb89..293b46b5a1 100644 --- a/driver/device_implicit_gemm_convolution_1_nchw_srck.cuh +++ b/driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -1,13 +1,15 @@ #pragma once -#include "gridwise_implicit_gemm_convolution_1_nchw_srck.cuh" +#include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh" +#include template -void device_implicit_gemm_convolution_1_nchw_srck(InDesc, +void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, const Tensor& in_nchw, WeiDesc, const Tensor& wei_kcsr, OutDesc, - Tensor& out_nkhw) + Tensor& out_nkhw, + unsigned nrepeat) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -101,6 +103,19 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc, constexpr unsigned HoPerThread = 2; constexpr unsigned WoPerThread = 1; + constexpr unsigned BlockSize = 128; +#elif 1 + constexpr unsigned NPerBlock = 2; + constexpr unsigned KPerBlock = 32; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 32; + + constexpr unsigned KPerThread = 4; + constexpr unsigned CPerThread = 2; + constexpr unsigned HoPerThread = 2; + constexpr unsigned WoPerThread = 2; + constexpr unsigned BlockSize = 128; #endif @@ -113,40 +128,46 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc, printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - cudaEvent_t start, stop; - float elapsedTime; + for(unsigned i = 0; i < nrepeat; ++i) + { + cudaEvent_t start, stop; + float elapsedTime; - cudaEventCreate(&start); - cudaEventRecord(start, 0); + cudaEventCreate(&start); + cudaEventRecord(start, 0); - gridwise_implicit_gemm_convolution_1_nchw_srck - <<>>(in_nchw_desc, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - wei_srck_desc, - static_cast(wei_srck_device_buf.GetDeviceBuffer()), - out_nkhw_desc, - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw + <<>>(in_nchw_desc, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + wei_srck_desc, + static_cast(wei_srck_device_buf.GetDeviceBuffer()), + out_nkhw_desc, + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - cudaEventCreate(&stop); - cudaEventRecord(stop, 0); - cudaEventSynchronize(stop); + cudaEventCreate(&stop); + cudaEventRecord(stop, 0); + cudaEventSynchronize(stop); + + cudaEventElapsedTime(&elapsedTime, start, stop); + printf("Elapsed time : %f ms\n", elapsedTime); + + usleep(10); + } - cudaEventElapsedTime(&elapsedTime, start, stop); - printf("Elapsed time : %f ms\n", elapsedTime); checkCudaErrors(cudaGetLastError()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); diff --git a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index be62bcb24a..21b5c3b43e 100644 --- a/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -90,6 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 2; + constexpr unsigned BPerBatch = 32; + constexpr unsigned BPerThread = 4; constexpr unsigned KPerThread = 16; constexpr unsigned CPerThread = 1; @@ -134,7 +136,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, CPerBlock, BPerThread, KPerThread, - CPerThread> + CPerThread, + BPerBatch> <<>>(in_cnhw_desc, static_cast(in_cnhw_device_buf.GetDeviceBuffer()), wei_srck_desc, diff --git a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck.cuh b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh similarity index 99% rename from src/include/gridwise_implicit_gemm_convolution_1_nchw_srck.cuh rename to src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh index 2550336bea..dc98754390 100644 --- a/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh @@ -22,7 +22,7 @@ template __global__ void -gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc, +gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, Float* const __restrict__ p_in_global, WeiGlobalDesc, Float* const __restrict__ p_wei_global, diff --git a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh index 9e456c0219..70f401e624 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh @@ -19,7 +19,8 @@ template + unsigned CPerThread, + unsigned BPerBatch> __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, Float* const __restrict__ p_in_global, @@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( Number{}, Number{}); // constexpr doesn't compile + static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n"); + const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( Number{}, - Number{}, + Number{}, Number{}); // constexpr doesn't compile const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( Number{}, Number{}); // constexpr doesn't compile - const auto blockwise_gemm = + const auto blockwise_batched_gemm = blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{}; @@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, { auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; - blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), + blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), p_in_block + s * Wi + r, p_out_thread, f_accum); @@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, // output: register to global mem, const auto matrix_c_index = - blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); const unsigned k_thread_data_begin = matrix_c_index.row_begin; - const unsigned b_thread_data_begin = matrix_c_index.col_begin; + const unsigned b_thread_data_begin = matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;