diff --git a/driver/conv.cu b/driver/conv.cu
index 898a6b76af..64a4ceb714 100644
--- a/driver/conv.cu
+++ b/driver/conv.cu
@@ -16,21 +16,7 @@ struct GeneratorTensor_1
     template <class... Is>
     double operator()(Is... is)
     {
-#if 0
-        return double(std::rand()) / double(RAND_MAX);
-#elif 1
         return 1;
-#elif 0
-        std::initializer_list<std::size_t> ls = {static_cast<std::size_t>(is)...};
-        return std::accumulate(ls.begin(), ls.end(), std::size_t(0));
-#else
-        assert(sizeof...(Is) > 0);
-        std::initializer_list<std::size_t> ids = {static_cast<std::size_t>(is)...};
-        std::vector<std::size_t> lens(sizeof...(Is), 100);
-        std::vector<std::size_t> strides(sizeof...(Is), 1);
-        std::partial_sum(lens.rbegin(), lens.rbegin() + (sizeof...(Is) - 1), strides.rbegin() + 1);
-        return std::inner_product(ids.begin(), ids.end(), strides.begin(), std::size_t(0)) + 1;
-#endif
     }
 };
@@ -46,6 +32,25 @@ struct GeneratorTensor_2
     }
 };
 
+struct GeneratorTensor_3
+{
+    template <class... Is>
+    double operator()(Is... is)
+    {
+#if 0
+        std::initializer_list<std::size_t> ls = {static_cast<std::size_t>(is)...};
+        return std::accumulate(ls.begin(), ls.end(), std::size_t(0));
+#elif 1
+        assert(sizeof...(Is) > 0);
+        std::initializer_list<std::size_t> ids = {static_cast<std::size_t>(is)...};
+        std::vector<std::size_t> lens(sizeof...(Is), 100);
+        std::vector<std::size_t> strides(sizeof...(Is), 1);
+        std::partial_sum(lens.rbegin(), lens.rbegin() + (sizeof...(Is) - 1), strides.rbegin() + 1);
+        return std::inner_product(ids.begin(), ids.end(), strides.begin(), std::size_t(0)) + 1;
+#endif
+    }
+};
+
 // this is ugly, only for 4d
 template <class TConstTensorDesc>
 void ostream_ConstantTensorDescriptor(TConstTensorDesc, std::ostream& os = std::cout)
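A note on what GeneratorTensor_3 computes, since the driver changes below switch to it: with every dimension length assumed to be 100, the partial_sum builds the row-major strides {..., 10000, 100, 1} and the inner_product collapses the coordinates into a 1-based linear index, so each element's value spells out its own coordinates in decimal digit groups. A minimal host-only sketch of the same arithmetic for three indices (the function name is invented for illustration):

    #include <cstddef>
    #include <cstdio>

    // Equivalent of GeneratorTensor_3 for three indices with lengths fixed at 100:
    // strides are {100 * 100, 100, 1}, and the value is the linear index plus one.
    double generator_tensor_3(std::size_t i0, std::size_t i1, std::size_t i2)
    {
        return double(i0 * 10000 + i1 * 100 + i2 + 1);
    }

    int main()
    {
        // element (1, 2, 3) -> 10204: the digit groups encode the coordinates,
        // which makes layout/reorder bugs visible in the LogRange dumps below
        std::printf("%.0f\n", generator_tensor_3(1, 2, 3));
        return 0;
    }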
@@ -338,7 +343,7 @@ int main()
     constexpr unsigned K = 1;
     constexpr unsigned S = 3;
     constexpr unsigned R = 3;
-#elif 0
+#elif 1
     constexpr unsigned N = 1;
     constexpr unsigned C = 1;
     constexpr unsigned HI = 34;
@@ -347,21 +352,21 @@ int main()
     constexpr unsigned WI = 34;
     constexpr unsigned S = 3;
     constexpr unsigned R = 3;
 #elif 1
-    constexpr unsigned N = 64;
-    constexpr unsigned C = 256;
+    constexpr unsigned N  = 64;
+    constexpr unsigned C  = 256;
     constexpr unsigned HI = 34;
     constexpr unsigned WI = 34;
-    constexpr unsigned K = 64;
-    constexpr unsigned S = 3;
-    constexpr unsigned R = 3;
+    constexpr unsigned K  = 64;
+    constexpr unsigned S  = 3;
+    constexpr unsigned R  = 3;
 #elif 0
-    constexpr unsigned N = 64;
-    constexpr unsigned C = 64;
+    constexpr unsigned N  = 64;
+    constexpr unsigned C  = 64;
     constexpr unsigned HI = 56;
     constexpr unsigned WI = 56;
-    constexpr unsigned K = 64;
-    constexpr unsigned S = 3;
-    constexpr unsigned R = 3;
+    constexpr unsigned K  = 64;
+    constexpr unsigned S  = 3;
+    constexpr unsigned R  = 3;
 #elif 0
     constexpr unsigned N = 64;
     constexpr unsigned C = 64;
@@ -374,34 +379,51 @@ int main()
 
     auto in_nchw_desc  = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
     auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
-    auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence<S, R, C, K>{});
     auto out_nkhw_desc =
         get_convolution_output_default_4d_tensor_descriptor(in_nchw_desc, wei_kcsr_desc);
 
     ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
     ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: ");
-    ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: ");
     ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
 
     Tensor<float> in_nchw(make_TensorDescriptor(in_nchw_desc));
     Tensor<float> wei_kcsr(make_TensorDescriptor(wei_kcsr_desc));
-    Tensor<float> wei_srck(make_TensorDescriptor(wei_srck_desc));
     Tensor<float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
     Tensor<float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
 
-#if 0
     std::size_t num_thread = std::thread::hardware_concurrency();
+
+#if 0
     in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
     wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-    wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-#elif 1
-    std::size_t num_thread = std::thread::hardware_concurrency();
+#elif 0
    in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
    wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-    wei_srck.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+#elif 0
+    in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+    wei_kcsr.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+#elif 1
+    in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
+    wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #endif
 
-    for(int i = 0; i < 40; ++i)
+#if 1
+    auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence<S, R, C, K>{});
+    Tensor<float> wei_srck(make_TensorDescriptor(wei_srck_desc));
+
+    auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) {
+        wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r);
+    };
+
+    make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)(num_thread);
+#endif
+
+#if 0
+    wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+    out_nkhw_device.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+#endif
+
+    for(int i = 0; i < 1; ++i)
     {
 #if 0
         device_direct_convolution_1(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
@@ -428,7 +450,7 @@ int main()
     check_error(out_nkhw_host, out_nkhw_device);
 #endif
 
-#if 0
+#if 1
     LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
     LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
    LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
diff --git a/driver/device_implicit_gemm_convolution.cuh b/driver/device_implicit_gemm_convolution.cuh
index 5eb9521653..1f776a2974 100644
--- a/driver/device_implicit_gemm_convolution.cuh
+++ b/driver/device_implicit_gemm_convolution.cuh
@@ -1,5 +1,5 @@
 #pragma once
-#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
+//#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
 #include "gridwise_implicit_gemm_convolution_nchw_srck.cuh"
 
 template <class T, class InDesc, class WeiDesc, class OutDesc>
@@ -26,20 +26,20 @@ void device_implicit_gemm_convolution(
     constexpr auto wei_desc = WeiDesc{};
     constexpr auto out_desc = OutDesc{};
 
-#if 0
-    constexpr unsigned NPerBlock  = 2;
-    constexpr unsigned KPerBlock  = 64;
-    constexpr unsigned CPerBlock  = 4;
+#if 1
+    constexpr unsigned NPerBlock  = 1;
+    constexpr unsigned KPerBlock  = 1;
+    constexpr unsigned CPerBlock  = 1;
     constexpr unsigned HoPerBlock = 2;
     constexpr unsigned WoPerBlock = 32;
 
-    constexpr unsigned NPerThread  = 2;
-    constexpr unsigned KPerThread  = 8;
-    constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 1;
-    constexpr unsigned WoPerThread = 4;
+    constexpr unsigned NPerThread  = 1;
+    constexpr unsigned KPerThread  = 1;
+    constexpr unsigned CPerThread  = 1;
+    constexpr unsigned HoPerThread = 2;
+    constexpr unsigned WoPerThread = 2;
 
-    constexpr unsigned BlockSize = 256;
+    constexpr unsigned BlockSize = 16;
 #elif 1
     constexpr unsigned NPerBlock  = 2;
     constexpr unsigned KPerBlock  = 32;
@@ -50,7 +50,7 @@ void device_implicit_gemm_convolution(
     constexpr unsigned NPerThread  = 2;
     constexpr unsigned KPerThread  = 4;
     constexpr unsigned CPerThread  = 2;
-    constexpr unsigned HoPerThread = 1;
+    constexpr unsigned HoPerThread = 2;
     constexpr unsigned WoPerThread = 2;
 
     constexpr unsigned BlockSize = 128;
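For context on the f_reorder_kcsr2srck lambda above: the driver now materializes an SRCK copy of the KCSR weights on the host before launching the SRCK kernel. A single-threaded sketch of the same data movement, with plain loops instead of make_ParallelTensorFunctor (idx4 is a hypothetical row-major helper, not part of the patch):

    #include <vector>

    // Hypothetical dense 4-d offset helper, only for this sketch.
    inline unsigned idx4(unsigned i0, unsigned i1, unsigned i2, unsigned i3,
                         unsigned l1, unsigned l2, unsigned l3)
    {
        return ((i0 * l1 + i1) * l2 + i2) * l3 + i3;
    }

    // Same data movement as f_reorder_kcsr2srck in the patch, single-threaded:
    // wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r).
    void reorder_kcsr_to_srck(const std::vector<float>& kcsr, std::vector<float>& srck,
                              unsigned K, unsigned C, unsigned S, unsigned R)
    {
        for(unsigned k = 0; k < K; ++k)
            for(unsigned c = 0; c < C; ++c)
                for(unsigned s = 0; s < S; ++s)
                    for(unsigned r = 0; r < R; ++r)
                        srck[idx4(s, r, c, k, R, C, K)] = kcsr[idx4(k, c, s, r, C, S, R)];
    }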
diff --git a/src/include/ConstantMatrixDescriptor.cuh b/src/include/ConstantMatrixDescriptor.cuh
index 352e0db0d8..6e6dfb694b 100644
--- a/src/include/ConstantMatrixDescriptor.cuh
+++ b/src/include/ConstantMatrixDescriptor.cuh
@@ -1,34 +1,36 @@
 #pragma once
 #include "common.cuh"
 
-template <unsigned NRow, unsigned NCol, unsigned RowStride>
+template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
 struct ConstantMatrixDescriptor
 {
     __host__ __device__ ConstantMatrixDescriptor()
     {
-        static_assert(NCol <= RowStride, "wrong! NCol > RowStride!");
+        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
     }
 
-    __host__ __device__ constexpr unsigned GetNumberOfRow() const { return NRow; }
+    __host__ __device__ constexpr unsigned NRow() const { return NRow_; }
 
-    __host__ __device__ constexpr unsigned GetNumberOfColumn() const { return NCol; }
+    __host__ __device__ constexpr unsigned NCol() const { return NCol_; }
 
-    __host__ __device__ constexpr unsigned GetRowStride() const { return RowStride; }
+    __host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }
 
-    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow * NCol; }
+    __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }
 
-    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow * RowStride; }
+    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }
+
+    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }
 
     __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const
     {
-        return irow * RowStride + icol;
+        return irow * RowStride_ + icol;
     }
 
     template <unsigned SubNRow, unsigned SubNCol>
     __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>, Number<SubNCol>) const
     {
-        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride>{};
+        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
     }
 };
diff --git a/src/include/blockwise_tensor_op.cuh b/src/include/blockwise_tensor_op.cuh
index 13e2093333..ed5b080e0a 100644
--- a/src/include/blockwise_tensor_op.cuh
+++ b/src/include/blockwise_tensor_op.cuh
@@ -135,6 +135,20 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
         const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
 
+#if 1
+        printf("did %u %u %u %u, did_IR %u %u %u %u, index %u %u\n",
+               did[0],
+               did[1],
+               did[2],
+               did[3],
+               did[IR0],
+               did[IR1],
+               did[IR2],
+               did[IR3],
+               aindex,
+               bindex);
+#endif
+
         f(p_src[aindex], p_dst[bindex]);
     }
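How the renamed ConstantMatrixDescriptor accessors compose, as a minimal sketch (sizes are arbitrary; the descriptor's non-constexpr constructor is why call sites throughout this patch use `const auto` with a "// constexpr doesn't compile" comment):

    // A 4 x 3 matrix embedded in rows of stride 8; RowStride_ >= NCol_ means
    // each row carries 5 elements of padding after its payload.
    __device__ void descriptor_demo()
    {
        const auto mtx = ConstantMatrixDescriptor<4, 3, 8>{}; // NRow_, NCol_, RowStride_

        const unsigned size  = mtx.GetElementSize();  // 4 * 3 = 12 payload elements
        const unsigned space = mtx.GetElementSpace(); // 4 * 8 = 32 allocated elements
        const unsigned off   = mtx.Get1dIndex(2, 1);  // 2 * 8 + 1 = 17

        (void)size; (void)space; (void)off; // silence unused warnings in this sketch
    }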
diff --git a/src/include/gemm.cuh b/src/include/gemm.cuh
index 6e08b0b1b1..0a3789580f 100644
--- a/src/include/gemm.cuh
+++ b/src/include/gemm.cuh
@@ -1,8 +1,27 @@
 #pragma once
 
+template <class SrcMatrix, class DstMatrix, class Float, unsigned NRow, unsigned NCol>
+__device__ void
+threadwise_matrix_copy(SrcMatrix, Float* const p_src, DstMatrix, Float* p_dst, Sequence<NRow, NCol>)
+{
+    const auto src_mtx = SrcMatrix{}; // constexpr doesn't compile
+    const auto dst_mtx = DstMatrix{}; // constexpr doesn't compile
+
+    for(unsigned i = 0; i < NRow; ++i)
+    {
+        for(unsigned j = 0; j < NCol; ++j)
+        {
+            const unsigned src_index = src_mtx.Get1dIndex(i, j);
+            const unsigned dst_index = dst_mtx.Get1dIndex(i, j);
+
+            p_dst[dst_index] = p_src[src_index];
+        }
+    }
+}
+
-template <class ThreadMatrixA,
-          class ThreadMatrixB,
-          class ThreadMatrixC,
+template <class MatrixA,
+          class MatrixB,
+          class MatrixC,
+          bool TransA,
+          bool TransB,
+          bool TransC,
           class FloatA,
           class FloatB,
           class FloatC,
           class Accumulator>
-__device__ void threadwise_gemm(ThreadMatrixA,
+__device__ void threadwise_gemm(MatrixA,
+                                Constant<bool, TransA>,
                                 FloatA* const p_a_thread,
-                                ThreadMatrixB,
+                                MatrixB,
+                                Constant<bool, TransB>,
                                 FloatB* const p_b_thread,
-                                ThreadMatrixC,
+                                MatrixC,
+                                Constant<bool, TransC>,
                                 FloatC* p_c_thread,
-                                Accumulator)
+                                Accumulator f_accum)
 {
-    // do something
+    if(TransA && (!TransB) && (!TransC))
+    {
+        const auto a_mtx = MatrixA{}; // constexpr doesn't compile
+        const auto b_mtx = MatrixB{}; // constexpr doesn't compile
+        const auto c_mtx = MatrixC{}; // constexpr doesn't compile
+
+        constexpr unsigned M = c_mtx.NRow();
+        constexpr unsigned N = c_mtx.NCol();
+        constexpr unsigned K = a_mtx.NRow(); // A is transposed
+
+        for(unsigned i = 0; i < M; ++i)
+        {
+            for(unsigned j = 0; j < N; ++j)
+            {
+                for(unsigned k = 0; k < K; ++k)
+                {
+                    const unsigned aindex = a_mtx.Get1dIndex(k, i); // A is transposed
+                    const unsigned bindex = b_mtx.Get1dIndex(k, j);
+                    const unsigned cindex = c_mtx.Get1dIndex(i, j);
+
+                    f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
+                }
+            }
+        }
+    }
+    else
+    {
+        // not implemented
+        assert(false);
+    }
 }
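Reference semantics of the TransA && !TransB && !TransC path implemented above, with the descriptors stripped away: A is stored K x M (so column i of the stored A holds row i of the logical A), B is K x N, C is M x N, and f_accum folds each product into C. A plain row-major sketch, assuming packed strides:

    // c += a^T * b, with a stored K x M, b stored K x N, c stored M x N.
    template <unsigned M, unsigned N, unsigned K>
    void gemm_tn_reference(const float* a /*K x M*/, const float* b /*K x N*/, float* c /*M x N*/)
    {
        for(unsigned i = 0; i < M; ++i)
            for(unsigned j = 0; j < N; ++j)
                for(unsigned k = 0; k < K; ++k)
                    c[i * N + j] += a[k * M + i] * b[k * N + j]; // f_accum(c, ab)
    }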
 
 template <unsigned BlockSize,
           class BlockMatrixA,
           class BlockMatrixB,
           class ThreadMatrixC,
           bool TransA,
           bool TransB,
           bool TransC,
           unsigned BlockMatrixStrideA,
           unsigned BlockMatrixStrideB,
           unsigned ThreadMatrixStrideC,
           unsigned BatchSize,
           unsigned BatchPerThread,
-          unsigned KPerThreadLoop>
+          unsigned KPerThreadLoop,
+          bool DistributeThreadAlongColumnFirst>
 struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
 {
     unsigned mMyThreadOffsetA = 0;
@@ -52,82 +100,177 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
     __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
     {
-        static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");
+        const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
+        const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
 
-#if 0
-        constexpr auto a_block_desc = BlockMatrixA{};
-        constexpr auto b_block_desc = BlockMatrixB{};
+        const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());
 
-        constexpr unsigned a_thread_row = (!TransA) ? MPerThread : KPerThread;
-        constexpr unsigned a_thread_col = (!TransA) ? KPerThread : MPerThread;
-        constexpr unsigned b_thread_row = (!TransB) ? KPerThread : NPerThread;
-        constexpr unsigned b_thread_col = (!TransB) ? NPerThread : KPerThread;
+        mMyThreadOffsetA = c_thread_mtx_index.batch_begin * a_block_mtx.GetElementSpace() +
+                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
+                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin));
 
-        constexpr auto a_thread_desc = ConstantMatrixDescriptor<a_thread_row, a_thread_col, a_thread_col>{};
-        constexpr auto b_thread_desc = ConstantMatrixDescriptor<b_thread_row, b_thread_col, b_thread_col>{};
-        constexpr auto c_thread_desc = ConstantMatrixDescriptor<MPerThread, NPerThread, NPerThread>{};
-
-        constexpr unsigned m_block = (!TransA) ? a_block_desc.NRow() : a_block_desc.NCol();
-        constexpr unsigned n_block = (!TransB) ? b_block_desc.NCol() : b_block_desc.NRow();
-
-        constexpr unsigned m_thread = (!TransA) ? a_thread_desc.NRow() : a_thread_desc.NCol();
-        constexpr unsigned n_thread = (!TransB) ? b_thread_desc.NCol() : b_thread_desc.NRow();
-
-        constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread;
-        constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread;
-        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;
-
-        static_assert(BlockSize >= ((BatchSize + BatchPerThread - 1) / BatchPerThread) *
-                                       num_threads_per_batch,
-                      "not enough thread!");
-
-        const auto mtx_c_idnex = CalculateThreadMatrixCIndex(get_thread_local_id());
-
-        // mMyThreadOffsetA = xxx;
-        // mMyThreadoffSetB = xxx;
-#else
-        mMyThreadOffsetA = 0;
-        mMyThreadOffsetB = 0;
-#endif
+        mMyThreadOffsetB = c_thread_mtx_index.batch_begin * b_block_mtx.GetElementSpace() +
+                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
+                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
     }
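The offsets computed in this constructor come from CalculateThreadMatrixCIndex just below, which linearizes the thread grid batch-major, then row, then column (the DistributeThreadAlongColumnFirst case). A host-side sketch of that decomposition, with struct and function names invented for illustration:

    struct MatrixIndexSketch { unsigned batch_begin, row_begin, col_begin; };

    // thread_id -> (batch, row, col) tile origin; n_work is the fastest axis.
    MatrixIndexSketch decompose(unsigned thread_id,
                                unsigned MThreadWork, unsigned NThreadWork,
                                unsigned BatchPerThread, unsigned MPerThread, unsigned NPerThread)
    {
        const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
        const unsigned itmp      = thread_id - b_work_id * (MThreadWork * NThreadWork);
        const unsigned m_work_id = itmp / NThreadWork;
        const unsigned n_work_id = itmp - m_work_id * NThreadWork;

        return {b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
    }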
wrong BlockSize"); + + // printf("%u %u, %u %u\n", get_block_1d_id(), get_thread_local_1d_id(), MThreadWork, + // NThreadWork); + + if(DistributeThreadAlongColumnFirst) + { + // num of operations can be reduced + const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork); + unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork); + const unsigned m_work_id = itmp / NThreadWork; + const unsigned n_work_id = itmp - m_work_id * NThreadWork; + + return MatrixIndex{ + b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread}; + } + else + { + // not implemented + assert(false); + } + } + else + { + // not implemented + assert(false); + } } - template - __device__ void run(FloatA* const p_a_block, FloatB* const p_b_block, FloatC* p_c_thread) const + template + __device__ void run(FloatA* const p_a_block, + FloatB* const p_b_block, + FloatC* p_c_thread, + Accumulator f_accum) const { - // do something + if(TransA && (!TransB) && (!TransC)) + { + constexpr auto True = Constant{}; + constexpr auto False = Constant{}; + + const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile + const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile + const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile + + constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + // a is transposed, b is not + const auto a_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto b_thread_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + // loop over k + for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) + { + // read first batch of a, b + threadwise_matrix_copy(a_block_mtx, + p_a_block + mMyThreadOffsetA + + k_begin * a_block_mtx.RowStride(), + a_thread_mtx, + p_a_thread, + a_thread_mtx.GetLengths()); + + threadwise_matrix_copy(b_block_mtx, + p_b_block + mMyThreadOffsetB + + k_begin * b_block_mtx.RowStride(), + b_thread_mtx, + p_b_thread, + b_thread_mtx.GetLengths()); + + // loop over batch + for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) + { + // do current batch of gemm + threadwise_gemm(a_thread_mtx, + True, + p_a_thread, + b_thread_mtx, + False, + p_b_thread, + c_thread_mtx, + False, + p_c_thread + ib * ThreadMatrixStrideC, + f_accum); + + // read next batch of a, b + if(BlockMatrixStrideA != 0) + { + threadwise_matrix_copy(a_block_mtx, + p_a_block + mMyThreadOffsetA + + (ib + 1) * BlockMatrixStrideA + + +k_begin * a_block_mtx.RowStride(), + a_thread_mtx, + p_a_thread, + a_thread_mtx.GetLengths()); + } + + if(BlockMatrixStrideB != 0) + { + threadwise_matrix_copy(b_block_mtx, + p_b_block + mMyThreadOffsetB + + (ib + 1) * BlockMatrixStrideB + + k_begin * b_block_mtx.RowStride(), + b_thread_mtx, + p_b_thread, + b_thread_mtx.GetLengths()); + } + } + + // do last batch of gemm + threadwise_gemm(a_thread_mtx, + True, + p_a_thread, + b_thread_mtx, + False, + p_b_thread, + c_thread_mtx, + False, + p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC, + f_accum); + } + } } }; diff --git a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh index 4b73e5a1af..1a546d6c97 100644 --- 
diff --git a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh
index 4b73e5a1af..1a546d6c97 100644
--- a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh
+++ b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh
@@ -90,13 +90,12 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
     constexpr auto out_hkwn_thread_desc = make_ConstantTensorDescriptor(
         Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
 
-#if 0
+#if 1
     if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
     {
         print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc");
         print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
 
-        print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc");
         print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
 
         print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
@@ -120,8 +119,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
     const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
         Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
 
-    auto f_accum = [](auto& c, auto& ab) { c += ab; };
-
     const auto blockwise_batch_gemm =
-        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<...>{};
+        blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<...,
+                                                                   true>{};
 
     // LDS
     constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
@@ -183,24 +180,29 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
 
         __syncthreads();
 
+#if 1
         // a series of batched GEMM
         for(unsigned s = 0; s < S; ++s)
         {
             for(unsigned r = 0; r < R; ++r)
             {
+                auto f_accum = [](auto& c, const auto& ab) { c += ab; };
+
                 blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
                                          p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0),
-                                         p_out_thread);
+                                         p_out_thread,
+                                         f_accum);
             }
         }
+#endif
     }
 
     const auto matrix_c_index =
         blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
 
     const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
-    const unsigned k_thread_data_begin  = matrix_c_index.col_begin;
-    const unsigned wo_thread_data_begin = matrix_c_index.row_begin / NPerThread;
+    const unsigned k_thread_data_begin  = matrix_c_index.row_begin;
+    const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
 
     // output: register to global mem,
     // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
@@ -216,4 +218,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
                                        wo_block_data_begin + wo_thread_data_begin),
         out_hkwn_thread_desc.GetLengths(),
         reorder_nkhw_from_hkwn);
+
+    // printf("%f %f %f %f\n", p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
+    // printf("%u %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(),
+    //        matrix_c_index.batch_begin, matrix_c_index.row_begin, matrix_c_index.col_begin);
+    // printf("%u %u, %u %u %u\n", get_block_1d_id(), get_thread_local_1d_id(),
+    //        ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin);
 }
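Why the row/col swap at the end matters: the per-thread C tile is K x (Wo * N) — see c_kxwn_thread_mtx_desc above — so row_begin indexes the K dimension and col_begin indexes the fused Wo*N dimension, with N fastest-varying. A sketch of the corrected mapping (helper names invented for illustration):

    // row_begin counts K directly; col_begin counts the fused (Wo, N) column,
    // so dividing by the per-thread N extent recovers the Wo tile coordinate.
    unsigned k_from_row(unsigned row_begin) { return row_begin; }

    unsigned wo_from_col(unsigned col_begin, unsigned n_per_thread)
    {
        return col_begin / n_per_thread;
    }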