// NOTE(review): SOURCE is a collapsed git patch (three new files) whose
// newlines and all angle-bracket contents (template parameter/argument lists,
// <<<...>>> launch configs, Sequence<...>/Number<...> args) were stripped in
// transit. The code below restores formatting and re-infers the stripped
// pieces; every inference is marked TODO(review) and must be checked against
// the original commit.

// ---------------------------------------------------------------------------
// driver/device_implicit_gemm_convolution.cuh (new file in the patch)
// ---------------------------------------------------------------------------
#pragma once
#include "gridwise_implicit_gemm_convolution.cuh"

// Host-side driver: uploads in/wei/out to device memory, launches the
// implicit-GEMM convolution kernel with one of the compile-time tuning
// configurations below, times the launch with CUDA events, and downloads the
// result into `out`.
// TODO(review): template parameter list was stripped; inferred from usage.
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution(
    InDesc, const Tensor<T>& in, WeiDesc, const Tensor<T>& wei, OutDesc, Tensor<T>& out)
{
    std::size_t data_sz = sizeof(T);
    DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
    DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace());
    DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace());

    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());
    // out is uploaded too, in case the kernel accumulates into existing data
    out_device_buf.ToDevice(out.mData.data());
    // (dropped: unused local "num_thread = std::thread::hardware_concurrency()")

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto out_desc = OutDesc{};
    // (dropped: unused locals in_desc / wei_desc)

    // Compile-time tuning configurations; exactly one branch is active.
#if 1
    constexpr unsigned OutTileSizeH = 2;
    constexpr unsigned OutTileSizeW = 2;
    constexpr unsigned NPerBlock    = 2;
    constexpr unsigned KPerBlock    = 32;
    constexpr unsigned CPerBlock    = 4;
    constexpr unsigned YPerBlock    = 1;
    constexpr unsigned XPerBlock    = 16;

    constexpr unsigned NPerThread = 2;
    constexpr unsigned KPerThread = 4;
    constexpr unsigned CPerThread = 2;

    constexpr unsigned BlockSize = 128;
#elif 0
    constexpr unsigned OutTileSizeH = 2;
    constexpr unsigned OutTileSizeW = 2;
    constexpr unsigned NPerBlock    = 2;
    constexpr unsigned KPerBlock    = 32;
    constexpr unsigned CPerBlock    = 4;
    constexpr unsigned YPerBlock    = 1;
    constexpr unsigned XPerBlock    = 27;

    constexpr unsigned NPerThread = 2;
    constexpr unsigned KPerThread = 4;
    constexpr unsigned CPerThread = 2;

    constexpr unsigned BlockSize = 216;
#elif 0
    constexpr unsigned OutTileSizeH = 2;
    constexpr unsigned OutTileSizeW = 2;
    constexpr unsigned NPerBlock    = 2;
    constexpr unsigned KPerBlock    = 32;
    constexpr unsigned CPerBlock    = 4;
    constexpr unsigned YPerBlock    = 1;
    constexpr unsigned XPerBlock    = 32;

    constexpr unsigned NPerThread = 2;
    constexpr unsigned KPerThread = 4;
    constexpr unsigned CPerThread = 2;

    constexpr unsigned BlockSize = 256;
#endif

    // One thread block per (N, K, Ho-tile, Wo-tile) output tile.
    // NOTE(review): assumes every out_desc length is divisible by its tile
    // size — no remainder handling is visible anywhere in the patch.
    constexpr unsigned GridSize = (out_desc.GetLength(I0) / NPerBlock) *
                                  (out_desc.GetLength(I1) / KPerBlock) *
                                  (out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) *
                                  (out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock));

    dim3 block_dim(BlockSize);
    dim3 grid_dim(GridSize);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    cudaEvent_t start, stop;
    float elapsedTime;

    // both events created up front (the patch created "stop" only after the
    // launch, which happens to work but reads as a bug)
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, 0);

    // NOTE(review): the kernel defined in gridwise_implicit_gemm_convolution.cuh
    // is named gridwise_implicit_gemm_convolution_nchw_kcsr — the two names
    // must be reconciled. The template-argument list and <<<...>>> brackets
    // were stripped from the patch; reconstructed below — TODO verify the
    // argument order against the kernel's template parameter list (the tuning
    // constants above are almost certainly part of it).
    gridwise_implicit_gemm_convolution<T, InDesc, WeiDesc, OutDesc>
        <<<grid_dim, block_dim>>>(InDesc{},
                                  static_cast<T*>(in_device_buf.GetDeviceBuffer()),
                                  WeiDesc{},
                                  static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
                                  OutDesc{},
                                  static_cast<T*>(out_device_buf.GetDeviceBuffer()));
    checkCudaErrors(cudaGetLastError()); // catch launch-configuration errors immediately

    cudaEventRecord(stop, 0);
    cudaEventSynchronize(stop);

    cudaEventElapsedTime(&elapsedTime, start, stop);
    printf("Elapsed time : %f ms\n", elapsedTime);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);

    checkCudaErrors(cudaGetLastError()); // catch asynchronous execution errors
    out_device_buf.FromDevice(out.mData.data());
}

// ---------------------------------------------------------------------------
// src/include/gemm.cuh (new file in the patch)
// ---------------------------------------------------------------------------
#pragma once

// Per-thread GEMM stub: accumulate op(A) * B into the thread's C tile held in
// registers. Body is unimplemented in the patch.
// TODO(review): template parameter list stripped; inferred from usage.
// A/B inputs taken as pointer-to-const (the patch wrote "FloatA* const" — a
// const pointer to mutable data, almost certainly not the intent).
template <class ThreadMatrixA, bool TransA, class FloatA,
          class ThreadMatrixB, bool TransB, class FloatB,
          class ThreadMatrixC, bool TransC, class FloatC,
          class Accumulator>
__device__ void threadwise_gemm(ThreadMatrixA,
                                Constant<bool, TransA>, // TODO(review): verify Constant's signature
                                const FloatA* p_a_thread,
                                ThreadMatrixB,
                                Constant<bool, TransB>,
                                const FloatB* p_b_thread,
                                ThreadMatrixC,
                                Constant<bool, TransC>,
                                FloatC* p_c_thread,
                                Accumulator)
{
    // do something (unimplemented in the patch)
}

// Blockwise 1-D strided batched GEMM: A and B tiles live in LDS, each thread
// accumulates its own C sub-tile in registers.
// TODO(review): template parameter list stripped; inferred from the names the
// member functions below actually use.
template <unsigned BlockSize,
          class BlockMatrixA, class BlockMatrixB, class BlockMatrixC,
          class ThreadMatrixA, class ThreadMatrixB, class ThreadMatrixC,
          bool TransA, bool TransB,
          unsigned BatchSize, unsigned BatchPerThread, unsigned ThreadMatrixStrideC>
struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
{
    // Where a thread's C tile sits: batch index plus row/col offsets within
    // the block-level C matrix.
    struct MatrixIndex
    {
        unsigned batch_begin;
        unsigned block_row_begin;
        unsigned block_col_begin;
}; // closes MatrixIndex

    // Validates the block/thread tiling at compile time and computes this
    // thread's LDS read offsets for A and B.
    __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
    {
        static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");

        constexpr auto a_block = BlockMatrixA{};
        constexpr auto b_block = BlockMatrixB{};

        constexpr auto a_thread = ThreadMatrixA{};
        constexpr auto b_thread = ThreadMatrixB{};
        constexpr auto c_thread = ThreadMatrixC{};

        // logical GEMM extents, honouring the transpose flags
        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();

        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();

        constexpr unsigned num_threads_per_row   = (m_block + m_thread - 1) / m_thread;
        constexpr unsigned num_threads_per_col   = (n_block + n_thread - 1) / n_thread;
        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;

        static_assert(BlockSize >= ((BatchSize + BatchPerThread - 1) / BatchPerThread) *
                                       num_threads_per_batch,
                      "not enough thread!");

        // (typo fixed: the patch wrote "mtx_c_idnex")
        const auto mtx_c_index = CalculateThreadMatrixCIndex(get_thread_local_id());

        // TODO(review): the patch left literal "xxx" placeholders here, and
        // wrote the second member as "mMyThreadoffSetB" (no such member —
        // the declared member is mMyThreadOffsetB). The offsets must be
        // derived from mtx_c_index; 0 is only a compile-able placeholder.
        mMyThreadOffsetA = 0; // TODO: compute from mtx_c_index
        mMyThreadOffsetB = 0; // TODO: compute from mtx_c_index
    }

    // Decompose a linear thread id into the (batch, row, col) coordinates of
    // the C tile that thread owns.
    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
    {
        constexpr auto a_block = BlockMatrixA{};
        constexpr auto b_block = BlockMatrixB{};
        constexpr auto c_block = BlockMatrixC{};

        constexpr auto a_thread = ThreadMatrixA{};
        constexpr auto b_thread = ThreadMatrixB{};
        constexpr auto c_thread = ThreadMatrixC{};

        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();

        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();

        constexpr unsigned num_threads_per_row   = (m_block + m_thread - 1) / m_thread;
        constexpr unsigned num_threads_per_col   = (n_block + n_thread - 1) / n_thread;
        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;

        // The patch flagged this block "wrong, need fix" and referenced an
        // undeclared "batch_id". Made self-consistent below.
        // TODO(review): verify whether threads are laid out row-major or
        // column-major over the C tiles — the row/col roles may still need
        // swapping relative to the intended layout.
        const unsigned batch_id    = thread_id / num_threads_per_batch;
        const unsigned batch_begin = batch_id * BatchPerThread;
        const unsigned tmp         = thread_id - batch_id * num_threads_per_batch;

        const unsigned thread_matrix_row_id = tmp / num_threads_per_row;
        const unsigned thread_matrix_col_id = tmp - thread_matrix_row_id * num_threads_per_row;

        return MatrixIndex{
            batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread};
    }

    // Run the batched GEMM over the LDS tiles, accumulating into the thread's
    // register tile. Body is unimplemented in the patch.
    // TODO(review): template parameter list stripped; A/B made
    // pointer-to-const (the patch wrote "FloatA* const").
    template <class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
        // do something (unimplemented in the patch)
    }

    private:
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;
}; // NOTE(review): the patch was missing this ';' after the struct

// ---------------------------------------------------------------------------
// src/include/gridwise_implicit_gemm_convolution.cuh (new file in the patch)
// ---------------------------------------------------------------------------
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "blockwise_tensor_op.cuh"
#include "threadwise_tensor_op.cuh"

// Implicit-GEMM convolution kernel: input NCHW, weight KCSR, output NKHW.
// Per C-chunk it stages an input tile (reordered to CHWN) and a weight tile
// (reordered to SRCK) in LDS, then runs a blockwise batched GEMM per filter
// tap (s, r), accumulating per-thread output tiles in registers.
// NOTE(review): the driver launches "gridwise_implicit_gemm_convolution"
// (without the _nchw_kcsr suffix) — the names must be reconciled.
// TODO(review): template parameter list stripped; inferred from usage.
// Input/weight pointers made pointer-to-const (the patch wrote
// "Float* const __restrict__", a const pointer to mutable data).
template <unsigned GridSize, unsigned BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          unsigned NPerBlock, unsigned KPerBlock, unsigned CPerBlock,
          unsigned HoPerBlock, unsigned WoPerBlock,
          unsigned OutTileSizeH, unsigned OutTileSizeW,
          unsigned KPerThread, unsigned CPerThread>
__global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
                                                             const Float* __restrict__ p_in_global,
                                                             WeiGlobalDesc,
                                                             const Float* __restrict__ p_wei_global,
                                                             OutGlobalDesc,
                                                             Float* __restrict__ p_out_global)
{
    // NPerThread == NPerBlock, because the input's LDS format is [C,Hi,Wi,N]:
    // for GEMM trans([C,K]) * [C,Wo*N] one thread must cover all of "N".
    // (With an LDS layout like [C,Hi,N,Wi,N] they could differ.)
    constexpr unsigned NPerThread = NPerBlock;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
// (continuation of the gridwise_implicit_gemm_convolution_nchw_kcsr body)
    constexpr auto I3 = Number<3>{};

    // TODO(review): angle-bracket contents stripped; presumably
    // Constant<bool, true/false>{}. Unused in the visible body.
    constexpr auto True  = Constant<bool, true>{};
    constexpr auto False = Constant<bool, false>{};

    constexpr auto in_nchw_global_desc  = InGlobalDesc{};
    constexpr auto wei_kcsr_global_desc = WeiGlobalDesc{};
    constexpr auto out_nkhw_global_desc = OutGlobalDesc{};

    // filter spatial extents
    constexpr unsigned S = wei_kcsr_global_desc.GetLength(I2);
    constexpr unsigned R = wei_kcsr_global_desc.GetLength(I3);

    // input tile = output tile plus the filter halo
    constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
    constexpr unsigned WiPerBlock = WoPerBlock + R - 1;

    // block-level tile descriptors (LDS layouts).
    // TODO(review): Sequence<> contents stripped; reconstructed from the
    // chwn / srck layout names — verify.
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});

    constexpr auto wei_srck_block_desc =
        make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});

    // LDS staging buffers
    constexpr unsigned in_block_size  = in_chwn_block_desc.GetElementSpace();
    constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();

    __shared__ Float p_in_block[in_block_size];
    __shared__ Float p_wei_block[wei_block_size];

    // per-thread output tile [Ho, K, Wo, N].
    // TODO(review): the patch left the literal placeholder "xxxxxx()" here —
    // this descriptor still has to be defined before anything compiles.
    constexpr auto out_hkwn_thread_desc = xxxxxx();

    // per-thread accumulator registers
    Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

    // --- GEMM plumbing, hoisted out of the C-loop --------------------------
    // Everything here is constexpr/loop-invariant, and blockwise_batch_gemm is
    // used *after* the loop as well: the patch constructed it inside the loop
    // body and then referenced it after the loop closed — a scope bug.

    // matrix view of the input tile: in_matrix[C, Hi*Wi*N]
    constexpr unsigned in_row = in_chwn_block_desc.GetLength(I0);
    constexpr unsigned in_col = in_chwn_block_desc.GetLength(I1) *
                                in_chwn_block_desc.GetLength(I2) *
                                in_chwn_block_desc.GetLength(I3);
    // TODO(review): Number<> arguments stripped; reconstructed as
    // (rows, cols, row stride) — verify make_ConstantMatrixDescriptor's contract.
    constexpr auto in_cxhwn_block_mtx_desc =
        make_ConstantMatrixDescriptor(Number<in_row>{}, Number<in_col>{}, Number<in_col>{});

    // matrix view of the weight tile: wei_matrix[S*R*C, K]
    constexpr unsigned wei_row = wei_srck_block_desc.GetLength(I0) *
                                 wei_srck_block_desc.GetLength(I1) *
                                 wei_srck_block_desc.GetLength(I2);
    constexpr unsigned wei_col = wei_srck_block_desc.GetLength(I3);
    constexpr auto wei_srcxk_block_mtx_desc =
        make_ConstantMatrixDescriptor(Number<wei_row>{}, Number<wei_col>{}, Number<wei_col>{});

    // blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix:
    //   A_matrix[C,K]     sub-matrix of wei_matrix[S*R*C,K]   (LDS)
    //   B_matrix[C,Wo*N]  sub-matrix of in_matrix[C,Hi*Wi*N]  (LDS)
    //   C_matrix[K,Wo*N]  sub-matrix of out_matrix[Ho*K,Wo*N] (registers)
    // TODO(review): MakeSubMatrixDescriptor arguments stripped; reconstructed
    // from the sub-matrix shapes above — verify.
    constexpr auto a_block_mtx_desc = wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{});

    constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor(
        Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{});

    // (fixed: the patch's lambda was missing its capture list "[]")
    auto f_accum = [](auto& c, auto& v) { c += v; };

    // TODO(review): template argument list stripped from the patch.
    const auto blockwise_batch_gemm =
        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{};

    // march over input channels in CPerBlock chunks; the trailing
    // __syncthreads() in the increment keeps the LDS tiles from being
    // overwritten while the previous iteration still reads them
    for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
        c_block_data_begin += CPerBlock, __syncthreads())
    {
        // input: global mem to LDS,
        // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N].
        // (fixed: the patch's loop bound referenced undeclared
        // "in_global_desc"; the declared name is in_nchw_global_desc)
        constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{};

        // NOTE(review): no per-block source offset (n/c/h/w block begins) is
        // applied to p_in_global — that computation appears to have been lost
        // from the patch and must be restored.
        blockwise_4d_tensor_copy_reorder(in_nchw_global_desc,
                                         p_in_global,
                                         in_chwn_block_desc,
                                         p_in_block,
                                         in_chwn_block_desc,
                                         reorder_nchw2chwn);

        // weight: global mem to LDS,
        // convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K].
        // (fixed: the patch referenced undeclared wei_csrk_* descriptors and
        // "reorder_kcsr2csrk"; the declared names are wei_kcsr_global_desc,
        // wei_srck_block_desc and reorder_kcsr2srck)
        constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{};

        blockwise_4d_tensor_copy_reorder(wei_kcsr_global_desc,
                                         p_wei_global,
                                         wei_srck_block_desc,
                                         p_wei_block,
                                         wei_srck_block_desc,
                                         reorder_kcsr2srck);

        __syncthreads(); // LDS tiles fully written before the GEMM reads them

        // one batched GEMM per filter tap
        for(unsigned s = 0; s < S; ++s)
        {
            for(unsigned r = 0; r < R; ++r)
            {
                // TODO(review): the patch left "xxxx" placeholders for these
                // offsets. Reconstruction: A starts at row (s*R + r)*CPerBlock
                // of wei_matrix[S*R*C,K]; B starts at column
                // (s*WiPerBlock + r)*NPerBlock of in_matrix[C,Hi*Wi*N] — verify.
                blockwise_batch_gemm.run(
                    p_wei_block + wei_srcxk_block_mtx_desc.Get1dIndex((s * R + r) * CPerBlock, 0),
                    p_in_block + in_cxhwn_block_mtx_desc.Get1dIndex(0, (s * WiPerBlock + r) * NPerBlock),
                    p_out_thread);
            }
        }
    }

    const auto matrix_c_index =
        blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_id());

    // (fixed: MatrixIndex's members are batch_begin / block_row_begin /
    // block_col_begin; the patch referenced non-existent .col_begin /
    // .row_begin)
    // NOTE(review): given C_matrix[K, Wo*N], taking K from the *column* and Wo
    // from the *row* looks swapped — kept exactly as the patch had it; verify.
    const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
    const unsigned k_thread_data_begin  = matrix_c_index.block_col_begin;
    const unsigned wo_thread_data_begin = matrix_c_index.block_row_begin / NPerThread;

    // output: register to global mem,
    // convert matrix out_matrix[Ho*K,Wo*N] to 4d-tensor out[N,K,Ho,Wo].
    // NOTE(review): n/k/ho/wo_block_data_begin are never computed in this
    // kernel (blockIdx is never decomposed) — that code is missing from the
    // patch and must be restored before this compiles.
    // (consistency fix: the patch called "GetIndex" here while every other
    // call site uses "Get1dIndex" — verify the descriptor API)
    constexpr auto reorder_hkwn2nkhw = Sequence<2, 1, 3, 0>{};
    threadwise_4d_tensor_copy_reorder(
        out_hkwn_thread_desc,
        p_out_thread,
        out_nkhw_global_desc,
        p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
                                                       k_block_data_begin + k_thread_data_begin,
                                                       ho_block_data_begin + ho_thread_data_begin,
                                                       wo_block_data_begin + wo_thread_data_begin),
        out_hkwn_thread_desc,
        reorder_hkwn2nkhw);
}