diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
index 3edd8253dd..7fad2713db 100644
--- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
@@ -160,7 +160,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr unsigned WeiBlockCopyDataPerRead = 4;
 
     constexpr unsigned BlockSize = 128;
-#elif 1
+#elif 0
     // 1x1, 28x28, 256 thread
     constexpr unsigned BPerBlock = 128;
     constexpr unsigned KPerBlock = 128;
@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     for(unsigned i = 0; i < nrepeat; ++i)
     {
         float time = launch_kernel(
-#if 0
+#if 1
             gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #else
             gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp
index d38161c078..1ae7ecb78b 100644
--- a/driver/driver.hip.cpp
+++ b/driver/driver.hip.cpp
@@ -661,9 +661,9 @@ int main(int argc, char* argv[])
         device_direct_convolution_2_nchw_kcyx_nkhw
 #elif 0
         device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
-#elif 1
-        device_implicit_gemm_convolution_1_chwn_cyxk_khwn
 #elif 0
+        device_implicit_gemm_convolution_1_chwn_cyxk_khwn
+#elif 1
         device_implicit_gemm_convolution_2_chwn_cyxk_khwn
 #endif
         (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp
new file mode 100644
index 0000000000..1218f173b3
--- /dev/null
+++ b/src/include/blockwise_batched_gemm.hip.hpp
@@ -0,0 +1,802 @@
+#pragma once
+#include "threadwise_gemm.hip.hpp"
+
+template <unsigned BlockSize,
+          class BlockMatrixA,
+          class BlockMatrixB,
+          class ThreadMatrixC,
+          bool TransA,
+          bool TransB,
+          bool TransC,
+          unsigned BlockMatrixStrideA,
+          unsigned BlockMatrixStrideB,
+          unsigned ThreadMatrixStrideC,
+          unsigned BatchSize,
+          unsigned BatchPerThread,
+          unsigned KPerThreadLoop,
+          bool DistributeThreadAlongColumnFirst>
+struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
+{
+    unsigned mMyThreadOffsetA = 0;
+    unsigned mMyThreadOffsetB = 0;
+
+    struct MatrixIndex
+    {
+        unsigned batch;
+        unsigned row;
+        unsigned col;
+    };
+
+    __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
+    {
+        constexpr auto a_block_mtx = BlockMatrixA{};
+        constexpr auto b_block_mtx = BlockMatrixB{};
+
+        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
+
+        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
+                           ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0)
+                                      : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row));
+
+        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
+                           ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col)
+                                      : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0));
+
+#if 0
+        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
+        {
+            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
+            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
+            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
+
+            printf("%u %u, %u %u %u, %u %u\n",
+                   get_block_1d_id(),
+                   get_thread_local_1d_id(),
+                   c_thread_mtx_index.batch,
+                   c_thread_mtx_index.row,
+                   c_thread_mtx_index.col,
+                   mMyThreadOffsetA,
+                   mMyThreadOffsetB);
+        }
+#endif
+    }
+
+    __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const
+    {
+
+        if(TransA && (!TransB) && (!TransC))
+        {
+            constexpr auto a_block_mtx = BlockMatrixA{};
+            constexpr auto b_block_mtx = BlockMatrixB{};
+
+            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
+                          "wrong! k dimension not consistent!");
+
+            constexpr unsigned MPerBlock = a_block_mtx.NCol();
+            constexpr unsigned NPerBlock = b_block_mtx.NCol();
+
+            constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+            // divide thread work
+            constexpr unsigned MPerThread = c_thread_mtx.NRow();
+            constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+            static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0");
+            static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
+            static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
+
+            constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
+            constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
+            constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
+
+            static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
+                          "wrong! wrong BlockSize");
+
+            if(DistributeThreadAlongColumnFirst)
+            {
+                // num of operations can be reduced
+                const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork);
+                unsigned itmp            = thread_id - b_work_id * (MThreadWork * NThreadWork);
+                const unsigned m_work_id = itmp / NThreadWork;
+                const unsigned n_work_id = itmp - m_work_id * NThreadWork;
+
+                return MatrixIndex{
+                    b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread};
+            }
+            else
+            {
+                // not implemented
+                assert(false);
+            }
+        }
+        else
+        {
+            // not implemented
+            assert(false);
+        }
+    }
+
+    // this should be optimized away if input is known
+    __device__ static MatrixIndex
+    GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c)
+    {
+        return MatrixIndex{batch_in_c, m_in_c, n_in_c};
+    }
+
+    template <class FloatA, class FloatB, class FloatC, class Accumulator>
+    __device__ void Run(const FloatA* __restrict__ p_a_block,
+                        const FloatB* __restrict__ p_b_block,
+                        FloatC* __restrict__ p_c_thread,
+                        Accumulator f_accum) const
+    {
+        if(TransA && (!TransB) && (!TransC))
+        {
+            constexpr auto True  = integral_constant<bool, true>{};
+            constexpr auto False = integral_constant<bool, false>{};
+
+            constexpr auto a_block_mtx  = BlockMatrixA{};
+            constexpr auto b_block_mtx  = BlockMatrixB{};
+            constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+            constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+
+            constexpr unsigned MPerThread = c_thread_mtx.NRow();
+            constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+            // a is transposed, b is not
+            constexpr auto a_thread_mtx =
+                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+
+            constexpr auto b_thread_mtx =
+                make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+
+            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
+            // loop over k
+            for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+            {
+                // read first batch of a, b
+                threadwise_matrix_copy(a_block_mtx,
+                                       p_a_block + mMyThreadOffsetA +
+                                           k_begin * a_block_mtx.RowStride(),
+                                       a_thread_mtx,
+                                       p_a_thread,
+                                       a_thread_mtx.GetLengths());
+
+                threadwise_matrix_copy(b_block_mtx,
+                                       p_b_block + mMyThreadOffsetB +
+                                           k_begin * b_block_mtx.RowStride(),
+                                       b_thread_mtx,
+                                       p_b_thread,
+                                       b_thread_mtx.GetLengths());
+
+                // loop over batch
+                for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+                {
+                    // do current batch of gemm
+                    threadwise_gemm(a_thread_mtx,
+                                    True,
+                                    p_a_thread,
+                                    b_thread_mtx,
+                                    False,
+                                    p_b_thread,
+                                    c_thread_mtx,
+                                    False,
+                                    p_c_thread + ib * ThreadMatrixStrideC,
+                                    f_accum);
+
+                    // read next batch of a, b
+                    if(BlockMatrixStrideA != 0)
+                    {
+                        threadwise_matrix_copy(a_block_mtx,
+                                               p_a_block + mMyThreadOffsetA
+ + (ib + 1) * BlockMatrixStrideA + + +k_begin * a_block_mtx.RowStride(), + a_thread_mtx, + p_a_thread, + a_thread_mtx.GetLengths()); + } + + if(BlockMatrixStrideB != 0) + { + threadwise_matrix_copy(b_block_mtx, + p_b_block + mMyThreadOffsetB + + (ib + 1) * BlockMatrixStrideB + + k_begin * b_block_mtx.RowStride(), + b_thread_mtx, + p_b_thread, + b_thread_mtx.GetLengths()); + } + } + + // do last batch of gemm + threadwise_gemm(a_thread_mtx, + True, + p_a_thread, + b_thread_mtx, + False, + p_b_thread, + c_thread_mtx, + False, + p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC, + f_accum); + } + } + } +}; + +template +struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 +{ + unsigned mMyThreadOffsetA = 0; + unsigned mMyThreadOffsetB = 0; + + struct MatrixIndex + { + unsigned batch; + unsigned row; + unsigned col; + }; + + __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2() + { + static_assert(BatchSize % BatchPerThread == 0, + "wrong! BatchSize is not dividable by BatchPerThread"); + + constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread; + + constexpr unsigned ThreadPerLevel1Cluster = + MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; + + static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster, + "wrong! wrong blocksize\n"); + + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), + "wrong! K dimension not consistent\n"); + + constexpr unsigned M = a_block_mtx.NCol(); // A is transposed + constexpr unsigned N = b_block_mtx.NCol(); + constexpr unsigned K = a_block_mtx.NRow(); + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0), + "wrong! Cannot evenly divide thread work among repeat \n"); + + constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; + constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + + static_assert((M % MRepeat == 0) && (N % NRepeat == 0), + "wrong! Cannot evenly divide work among repeat\n"); + + constexpr unsigned MPerLevel1Cluster = M / MRepeat; + constexpr unsigned NPerLevel1Cluster = N / NRepeat; + + static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) && + (NPerLevel1Cluster % NLevel1Cluster == 0), + "wrong! Cannot evenly divide work among Level1Cluster\n"); + + constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster; + constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster; + + static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) && + (NPerLevel0Cluster % NLevel0Cluster == 0), + "wrong! Cannot evenly divide work among Level0Cluster\n"); + + static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) && + (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster), + "wrong! 
thread work size is wrong\n"); + + const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA + + a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row); + + mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB + + b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: "); + print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: "); + print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: "); + + printf("%u %u, %u %u %u, %u %u\n", + get_block_1d_id(), + get_thread_local_1d_id(), + c_thread_mtx_index.batch, + c_thread_mtx_index.row, + c_thread_mtx_index.col, + mMyThreadOffsetA, + mMyThreadOffsetB); + } +#endif + } + + __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const + { + constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread; + + constexpr unsigned ThreadPerLevel1Cluster = + MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; + + constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; + + unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster; + unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster; + + unsigned level1_id = cluster_id / ThreadPerLevel0Cluster; + unsigned level1_m_id = level1_id / NLevel1Cluster; + unsigned level1_n_id = level1_id % NLevel1Cluster; + + unsigned level0_id = cluster_id % ThreadPerLevel0Cluster; + unsigned level0_m_id = level0_id / NLevel0Cluster; + unsigned level0_n_id = level0_id % NLevel0Cluster; + + constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; + constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; + + return MatrixIndex{batch_work_id * BatchPerThread, + level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, + level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; + } + + // this should be optimized away if input is known + __device__ static MatrixIndex + GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c) + { + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; + constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + + constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + + unsigned m_repeat = m_in_c / MPerThreadSubC; + unsigned n_repeat = n_in_c / NPerThreadSubC; + + unsigned m_in_sub_c = m_in_c % MPerThreadSubC; + unsigned n_in_sub_c = n_in_c % NPerThreadSubC; + + return MatrixIndex{batch_in_c, + m_repeat * MPerLevel1Cluster + m_in_sub_c, + n_repeat * NPerLevel1Cluster + n_in_sub_c}; + } + + template + __device__ void Run(const FloatA* __restrict__ p_a_block, + const FloatB* __restrict__ p_b_block, + FloatC* __restrict__ p_c_thread, + Accumulator f_accum) const + { + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed + + constexpr unsigned MPerThread = 
c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + // thread A, B for GEMM + // A is transposed, b is not + constexpr auto a_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + constexpr auto b_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + // thread A-sub, B-sub for copy + constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + + constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; + constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + +// loop over k +#pragma unroll + for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) + { +// read first batch of A, B +// copy A-sub to form A +#pragma unroll + for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { + threadwise_matrix_copy( + a_block_mtx, + p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + + mMyThreadOffsetA, + a_thread_mtx, + p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), + a_thread_sub_mtx.GetLengths()); + } + +// copy B-sub to form B +#pragma unroll + for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { + threadwise_matrix_copy( + b_block_mtx, + p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + + mMyThreadOffsetB, + b_thread_mtx, + p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), + b_thread_sub_mtx.GetLengths()); + } + +// loop over batch +#pragma unroll + for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) + { + // do current batch of gemm + threadwise_gemm(a_thread_mtx, + True, + p_a_thread, + b_thread_mtx, + False, + p_b_thread, + c_thread_mtx, + False, + p_c_thread + ib * ThreadMatrixStrideC, + f_accum); + + // read next batch of a, b + if(BlockMatrixStrideA != 0) + { +#pragma unroll + for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { + threadwise_matrix_copy( + a_block_mtx, + p_a_block + + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + + (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA, + a_thread_mtx, + p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), + a_thread_sub_mtx.GetLengths()); + } + } + + if(BlockMatrixStrideB != 0) + { +#pragma unroll + for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { + threadwise_matrix_copy( + b_block_mtx, + p_b_block + + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + + (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB, + b_thread_mtx, + p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), + b_thread_sub_mtx.GetLengths()); + } + } + } + + // do last batch of gemm + threadwise_gemm(a_thread_mtx, + True, + p_a_thread, + b_thread_mtx, + False, + p_b_thread, + c_thread_mtx, + False, + p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC, + f_accum); + } + } + + template + __device__ void Run_v3(const FloatA* __restrict__ p_a_block, + const FloatB* __restrict__ p_b_block, + FloatC* __restrict__ p_c_thread, + Accumulator f_accum) const + { + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; + + constexpr auto a_block_mtx = 
BlockMatrixA{};
+        constexpr auto b_block_mtx  = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        // thread A, B for GEMM
+        // A is transposed, b is not
+        constexpr auto a_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+
+        // thread A-sub, B-sub for copy
+        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
+        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        // loop over k
+        //#pragma unroll
+        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        {
+            // read first batch of A, B
+            // copy A-sub to form A
+            //#pragma unroll
+            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            {
+                for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
+                {
+#if 1
+                    for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
+                    {
+                        p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
+                            p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
+                                                             m_repeat * MPerLevel1Cluster + j) +
+                                      mMyThreadOffsetA];
+                    }
+#else
+                    static_assert(a_thread_sub_mtx.NCol() == 4, "asm only read 4xfp32");
+
+#endif
+                }
+            }
+
+            // copy B-sub to form B
+            //#pragma unroll
+            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            {
+                for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
+                    {
+                        p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
+                            p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
+                                                             n_repeat * NPerLevel1Cluster + j) +
+                                      mMyThreadOffsetB];
+                    }
+                }
+            }
+
+            // loop over batch
+            //#pragma unroll
+            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+            {
+                // do current batch of gemm
+                for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+                {
+#if 0
+                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                    {
+                        for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                        {
+                            const unsigned aindex =
+                                a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                            const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
+                            const unsigned cindex =
+                                c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
+
+                            f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
+                        }
+                    }
+#elif 1
+                    static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
+                                  "asm is only for 16x4");
+
+                    const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                    {
+                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                        const unsigned cindex =
+                            c_thread_mtx.Get1dIndex(i, 0) + ib * ThreadMatrixStrideC;
+
+                        asm volatile("\n \
+                            v_mac_f32 %0, %4, %5 \n \
+                            v_mac_f32 %1, %4, %6 \n \
+                            v_mac_f32 %2, %4, %7 \n \
+                            v_mac_f32 %3, %4, %8 \n \
+                            "
+                                     : "=v"(p_c_thread[cindex + 0]),
+                                       "=v"(p_c_thread[cindex + 1]),
+                                       "=v"(p_c_thread[cindex + 2]),
+                                       "=v"(p_c_thread[cindex + 3])
+                                     : "v"(p_a_thread[aindex]),
+                                       "v"(p_b_thread[bindex + 0]),
+                                       "v"(p_b_thread[bindex + 1]),
+                                       "v"(p_b_thread[bindex + 2]),
+                                       "v"(p_b_thread[bindex + 3]),
+                                       "0"(p_c_thread[cindex + 0]),
+                                       "1"(p_c_thread[cindex + 1]),
+                                       "2"(p_c_thread[cindex + 2]),
+                                       "3"(p_c_thread[cindex + 3]));
+                    }
+#endif
+                }
+
+                // read next batch of a, b
+                if(BlockMatrixStrideA != 0)
+                {
+                    //#pragma unroll
+                    for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+                    {
+                        for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
+                        {
+                            for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
+                            {
+                                p_a_thread[a_thread_mtx.Get1dIndex(i,
+                                                                   m_repeat * MPerThreadSubC + j)] =
+                                    p_a_block[a_block_mtx.Get1dIndex(
+                                                  k_begin + i, m_repeat * MPerLevel1Cluster + j) +
+                                              (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
+                            }
+                        }
+                    }
+                }
+
+                if(BlockMatrixStrideB != 0)
+                {
+                    //#pragma unroll
+                    for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+                    {
+                        for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
+                        {
+                            for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
+                            {
+                                p_b_thread[b_thread_mtx.Get1dIndex(i,
+                                                                   n_repeat * NPerThreadSubC + j)] =
+                                    p_b_block[b_block_mtx.Get1dIndex(
+                                                  k_begin + i, n_repeat * NPerLevel1Cluster + j) +
+                                              (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
+                            }
+                        }
+                    }
+                }
+            }
+
+            // do last batch of gemm
+            for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+            {
+#if 0
+                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                    {
+                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                        const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
+                        const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
+                                                (BatchPerThread - 1) * ThreadMatrixStrideC;
+
+                        f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
+                    }
+                }
+#elif 1
+                static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
+                              "asm is only for 16x4");
+
+                const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                {
+                    const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                    const unsigned cindex =
+                        c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
+
+                    asm volatile("\n \
+                        v_mac_f32 %0, %4, %5 \n \
+                        v_mac_f32 %1, %4, %6 \n \
+                        v_mac_f32 %2, %4, %7 \n \
+                        v_mac_f32 %3, %4, %8 \n \
+                        "
+                                 : "=v"(p_c_thread[cindex + 0]),
+                                   "=v"(p_c_thread[cindex + 1]),
+                                   "=v"(p_c_thread[cindex + 2]),
+                                   "=v"(p_c_thread[cindex + 3])
+                                 : "v"(p_a_thread[aindex]),
+                                   "v"(p_b_thread[bindex + 0]),
+                                   "v"(p_b_thread[bindex + 1]),
+                                   "v"(p_b_thread[bindex + 2]),
+                                   "v"(p_b_thread[bindex + 3]),
+                                   "0"(p_c_thread[cindex + 0]),
+                                   "1"(p_c_thread[cindex + 1]),
+                                   "2"(p_c_thread[cindex + 2]),
+                                   "3"(p_c_thread[cindex + 3]));
+                }
+#endif
+            }
+        }
+    }
+
+    template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
+    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
+                                                    FloatC* __restrict__ p_c_block) const
+    {
+        constexpr auto c_block_mtx  = BlockMatrixC{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+
+        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        const auto c_thread_mtx_begin =
GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const unsigned c_thread_offset = + c_thread_mtx_begin.batch * BlockMatrixStrideC + + c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col); + + for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { + for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { + threadwise_matrix_copy( + c_thread_sub_mtx, + p_c_thread + + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, + n_repeat * NPerLevel1Cluster), + c_block_mtx, + p_c_block + + c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, + n_repeat * NPerLevel1Cluster) + + c_thread_offset, + c_thread_sub_mtx.GetLengths()); + } + } + } +}; diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 3ef2e036d4..9471776a74 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -1,961 +1,6 @@ #pragma once #include "threadwise_gemm.hip.hpp" -template -struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC -{ - unsigned mMyThreadOffsetA = 0; - unsigned mMyThreadOffsetB = 0; - - struct MatrixIndex - { - unsigned batch; - unsigned row; - unsigned col; - }; - - __device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC() - { - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - - const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA + - ((!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row, 0) - : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row)); - - mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB + - ((!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col) - : b_block_mtx.Get1dIndex(c_thread_mtx_index.col, 0)); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: "); - print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: "); - print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: "); - - printf("%u %u, %u %u %u, %u %u\n", - get_block_1d_id(), - get_thread_local_1d_id(), - c_thread_mtx_index.batch, - c_thread_mtx_index.row, - c_thread_mtx_index.col, - mMyThreadOffsetA, - mMyThreadOffsetB); - } -#endif - } - - __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const - { - - if(TransA && (!TransB) && (!TransC)) - { - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - - static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), - "wrong! k dimension not consistent!"); - - constexpr unsigned MPerBlock = a_block_mtx.NCol(); - constexpr unsigned NPerBlock = b_block_mtx.NCol(); - - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - // divide thread work - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - static_assert(BatchSize % BatchPerThread == 0, "BatchSize % BatchPerThread != 0"); - static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0"); - static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0"); - - constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread; - constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread; - constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; - - static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork, - "wrong! 
wrong BlockSize"); - - if(DistributeThreadAlongColumnFirst) - { - // num of operations can be reduced - const unsigned b_work_id = thread_id / (MThreadWork * NThreadWork); - unsigned itmp = thread_id - b_work_id * (MThreadWork * NThreadWork); - const unsigned m_work_id = itmp / NThreadWork; - const unsigned n_work_id = itmp - m_work_id * NThreadWork; - - return MatrixIndex{ - b_work_id * BatchPerThread, m_work_id * MPerThread, n_work_id * NPerThread}; - } - else - { - // not implemented - assert(false); - } - } - else - { - // not implemented - assert(false); - } - } - - // this should be optimized away if input is known - __device__ static MatrixIndex - GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c) - { - return MatrixIndex{batch_in_c, m_in_c, n_in_c}; - } - - template - __device__ void Run(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const - { - if(TransA && (!TransB) && (!TransC)) - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed - - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - // a is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - // loop over k - for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) - { - // read first batch of a, b - threadwise_matrix_copy(a_block_mtx, - p_a_block + mMyThreadOffsetA + - k_begin * a_block_mtx.RowStride(), - a_thread_mtx, - p_a_thread, - a_thread_mtx.GetLengths()); - - threadwise_matrix_copy(b_block_mtx, - p_b_block + mMyThreadOffsetB + - k_begin * b_block_mtx.RowStride(), - b_thread_mtx, - p_b_thread, - b_thread_mtx.GetLengths()); - - // loop over batch - for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) - { - // do current batch of gemm - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + ib * ThreadMatrixStrideC, - f_accum); - - // read next batch of a, b - if(BlockMatrixStrideA != 0) - { - threadwise_matrix_copy(a_block_mtx, - p_a_block + mMyThreadOffsetA + - (ib + 1) * BlockMatrixStrideA + - +k_begin * a_block_mtx.RowStride(), - a_thread_mtx, - p_a_thread, - a_thread_mtx.GetLengths()); - } - - if(BlockMatrixStrideB != 0) - { - threadwise_matrix_copy(b_block_mtx, - p_b_block + mMyThreadOffsetB + - (ib + 1) * BlockMatrixStrideB + - k_begin * b_block_mtx.RowStride(), - b_thread_mtx, - p_b_thread, - b_thread_mtx.GetLengths()); - } - } - - // do last batch of gemm - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC, - f_accum); - } - } - } -}; - -template -struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 -{ - unsigned mMyThreadOffsetA = 0; - unsigned mMyThreadOffsetB = 0; - - struct MatrixIndex - { - unsigned batch; - unsigned row; - unsigned col; - }; - - __device__ 
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2() - { - static_assert(BatchSize % BatchPerThread == 0, - "wrong! BatchSize is not dividable by BatchPerThread"); - - constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread; - - constexpr unsigned ThreadPerLevel1Cluster = - MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; - - static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster, - "wrong! wrong blocksize\n"); - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), - "wrong! K dimension not consistent\n"); - - constexpr unsigned M = a_block_mtx.NCol(); // A is transposed - constexpr unsigned N = b_block_mtx.NCol(); - constexpr unsigned K = a_block_mtx.NRow(); - - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0), - "wrong! Cannot evenly divide thread work among repeat \n"); - - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; - - static_assert((M % MRepeat == 0) && (N % NRepeat == 0), - "wrong! Cannot evenly divide work among repeat\n"); - - constexpr unsigned MPerLevel1Cluster = M / MRepeat; - constexpr unsigned NPerLevel1Cluster = N / NRepeat; - - static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) && - (NPerLevel1Cluster % NLevel1Cluster == 0), - "wrong! Cannot evenly divide work among Level1Cluster\n"); - - constexpr unsigned MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster; - constexpr unsigned NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster; - - static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) && - (NPerLevel0Cluster % NLevel0Cluster == 0), - "wrong! Cannot evenly divide work among Level0Cluster\n"); - - static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) && - (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster), - "wrong! 
thread work size is wrong\n"); - - const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA + - a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row); - - mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB + - b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: "); - print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: "); - print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: "); - - printf("%u %u, %u %u %u, %u %u\n", - get_block_1d_id(), - get_thread_local_1d_id(), - c_thread_mtx_index.batch, - c_thread_mtx_index.row, - c_thread_mtx_index.col, - mMyThreadOffsetA, - mMyThreadOffsetB); - } -#endif - } - - __device__ MatrixIndex GetBeginOfThreadMatrixC(unsigned thread_id) const - { - constexpr unsigned BatchThreadWork = BatchSize / BatchPerThread; - - constexpr unsigned ThreadPerLevel1Cluster = - MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; - - constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; - - unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster; - unsigned cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster; - - unsigned level1_id = cluster_id / ThreadPerLevel0Cluster; - unsigned level1_m_id = level1_id / NLevel1Cluster; - unsigned level1_n_id = level1_id % NLevel1Cluster; - - unsigned level0_id = cluster_id % ThreadPerLevel0Cluster; - unsigned level0_m_id = level0_id / NLevel0Cluster; - unsigned level0_n_id = level0_id % NLevel0Cluster; - - constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; - constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; - - return MatrixIndex{batch_work_id * BatchPerThread, - level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, - level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; - } - - // this should be optimized away if input is known - __device__ static MatrixIndex - GetDistanceFromBeginOfThreadMatrixC(unsigned batch_in_c, unsigned m_in_c, unsigned n_in_c) - { - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; - - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - unsigned m_repeat = m_in_c / MPerThreadSubC; - unsigned n_repeat = n_in_c / NPerThreadSubC; - - unsigned m_in_sub_c = m_in_c % MPerThreadSubC; - unsigned n_in_sub_c = n_in_c % NPerThreadSubC; - - return MatrixIndex{batch_in_c, - m_repeat * MPerLevel1Cluster + m_in_sub_c, - n_repeat * NPerLevel1Cluster + n_in_sub_c}; - } - - template - __device__ void Run(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed - - constexpr unsigned MPerThread = 
c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - // thread A, B for GEMM - // A is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - // thread A-sub, B-sub for copy - constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; - -// loop over k -#pragma unroll - for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) - { -// read first batch of A, B -// copy A-sub to form A -#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - threadwise_matrix_copy( - a_block_mtx, - p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + - mMyThreadOffsetA, - a_thread_mtx, - p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths()); - } - -// copy B-sub to form B -#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - threadwise_matrix_copy( - b_block_mtx, - p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + - mMyThreadOffsetB, - b_thread_mtx, - p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths()); - } - -// loop over batch -#pragma unroll - for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) - { - // do current batch of gemm - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + ib * ThreadMatrixStrideC, - f_accum); - - // read next batch of a, b - if(BlockMatrixStrideA != 0) - { -#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - threadwise_matrix_copy( - a_block_mtx, - p_a_block + - a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + - (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA, - a_thread_mtx, - p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths()); - } - } - - if(BlockMatrixStrideB != 0) - { -#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - threadwise_matrix_copy( - b_block_mtx, - p_b_block + - b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + - (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB, - b_thread_mtx, - p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths()); - } - } - } - - // do last batch of gemm - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC, - f_accum); - } - } - - template - __device__ void Run_v2(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = 
BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed - - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - // thread A, B for GEMM - // A is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - // thread A-sub, B-sub for copy - constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; - - // loop over k - //#pragma unroll - for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) - { - // read first batch of A, B - // copy A-sub to form A - //#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j) - { - p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] = - p_a_block[a_block_mtx.Get1dIndex(k_begin + i, - m_repeat * MPerLevel1Cluster + j) + - mMyThreadOffsetA]; - } - } - } - - // copy B-sub to form B - //#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j) - { - p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] = - p_b_block[b_block_mtx.Get1dIndex(k_begin + i, - n_repeat * MPerLevel1Cluster + j) + - mMyThreadOffsetB]; - } - } - } - - // loop over batch - //#pragma unroll - for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) - { - // do current batch of gemm - for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) - { - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) - { - const unsigned aindex = - a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); - const unsigned cindex = - c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC; - - f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); - } - } - } - - // read next batch of a, b - if(BlockMatrixStrideA != 0) - { - //#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j) - { - p_a_thread[a_thread_mtx.Get1dIndex(i, - m_repeat * MPerThreadSubC + j)] = - p_a_block[a_block_mtx.Get1dIndex( - k_begin + i, m_repeat * MPerLevel1Cluster + j) + - (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA]; - } - } - } - } - - if(BlockMatrixStrideB != 0) - { - //#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j) - { - p_b_thread[b_thread_mtx.Get1dIndex(i, - n_repeat * 
NPerThreadSubC + j)] = - p_b_block[b_block_mtx.Get1dIndex( - k_begin + i, n_repeat * MPerLevel1Cluster + j) + - (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB]; - } - } - } - } - } - - // do last batch of gemm - for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) - { - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) - { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); - const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) + - (BatchPerThread - 1) * ThreadMatrixStrideC; - - f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); - } - } - } - } - } - - template - __device__ void Run_v3(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread, - Accumulator f_accum) const - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed - - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - // thread A, B for GEMM - // A is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - // thread A-sub, B-sub for copy - constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; - - // loop over k - //#pragma unroll - for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) - { - // read first batch of A, B - // copy A-sub to form A - //#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j) - { - p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] = - p_a_block[a_block_mtx.Get1dIndex(k_begin + i, - m_repeat * MPerLevel1Cluster + j) + - mMyThreadOffsetA]; - } - } - } - - // copy B-sub to form B - //#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j) - { - p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] = - p_b_block[b_block_mtx.Get1dIndex(k_begin + i, - n_repeat * MPerLevel1Cluster + j) + - mMyThreadOffsetB]; - } - } - } - - // loop over batch - //#pragma unroll - for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib) - { - // do current batch of gemm - for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) - { -#if 0 - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) - { - const unsigned aindex = - a_thread_mtx.Get1dIndex(k, i); // A is 
transposed - const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); - const unsigned cindex = - c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC; - - f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); - } - } -#elif 1 - static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4, - "asm is only for 16x4"); - - const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) - { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0); - - asm volatile("\n \ - v_mac_f32 %0, %4, %5 \n \ - v_mac_f32 %1, %4, %6 \n \ - v_mac_f32 %2, %4, %7 \n \ - v_mac_f32 %3, %4, %8 \n \ - " - : "=v"(p_c_thread[cindex + 0]), - "=v"(p_c_thread[cindex + 1]), - "=v"(p_c_thread[cindex + 2]), - "=v"(p_c_thread[cindex + 3]) - : "v"(p_a_thread[aindex]), - "v"(p_b_thread[bindex + 0]), - "v"(p_b_thread[bindex + 1]), - "v"(p_b_thread[bindex + 2]), - "v"(p_b_thread[bindex + 3]), - "0"(p_c_thread[cindex + 0]), - "1"(p_c_thread[cindex + 1]), - "2"(p_c_thread[cindex + 2]), - "3"(p_c_thread[cindex + 3])); - } -#endif - } - - // read next batch of a, b - if(BlockMatrixStrideA != 0) - { - //#pragma unroll - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j) - { - p_a_thread[a_thread_mtx.Get1dIndex(i, - m_repeat * MPerThreadSubC + j)] = - p_a_block[a_block_mtx.Get1dIndex( - k_begin + i, m_repeat * MPerLevel1Cluster + j) + - (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA]; - } - } - } - } - - if(BlockMatrixStrideB != 0) - { - //#pragma unroll - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j) - { - p_b_thread[b_thread_mtx.Get1dIndex(i, - n_repeat * NPerThreadSubC + j)] = - p_b_block[b_block_mtx.Get1dIndex( - k_begin + i, n_repeat * MPerLevel1Cluster + j) + - (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB]; - } - } - } - } - } - - // do last batch of gemm - for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) - { -#if 0 - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) - { - for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j) - { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned bindex = b_thread_mtx.Get1dIndex(k, j); - const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) + - (BatchPerThread - 1) * ThreadMatrixStrideC; - - f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); - } - } -#elif 1 - static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4, - "asm is only for 16x4"); - - const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0); - for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i) - { - const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed - const unsigned cindex = - c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC; - - asm volatile("\n \ - v_mac_f32 %0, %4, %5 \n \ - v_mac_f32 %1, %4, %6 \n \ - v_mac_f32 %2, %4, %7 \n \ - v_mac_f32 %3, %4, %8 \n \ - " - : "=v"(p_c_thread[cindex + 0]), - "=v"(p_c_thread[cindex + 1]), - "=v"(p_c_thread[cindex + 2]), - "=v"(p_c_thread[cindex + 3]) - : "v"(p_a_thread[aindex]), - "v"(p_b_thread[bindex + 0]), - "v"(p_b_thread[bindex + 1]), - "v"(p_b_thread[bindex + 2]), - "v"(p_b_thread[bindex + 3]), - "0"(p_c_thread[cindex + 0]), - "1"(p_c_thread[cindex + 1]), - "2"(p_c_thread[cindex + 2]), - 
"3"(p_c_thread[cindex + 3])); - } -#endif - } - } - } - - template - __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread, - FloatC* __restrict__ p_c_block) const - { - constexpr auto c_block_mtx = BlockMatrixC{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr unsigned MPerThread = c_thread_mtx.NRow(); - constexpr unsigned NPerThread = c_thread_mtx.NCol(); - - constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; - constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; - - const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const unsigned c_thread_offset = - c_thread_mtx_begin.batch * BlockMatrixStrideC + - c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col); - - for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - threadwise_matrix_copy( - c_thread_sub_mtx, - p_c_thread + - c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, - n_repeat * NPerLevel1Cluster), - c_block_mtx, - p_c_block + - c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster, - n_repeat * NPerLevel1Cluster) + - c_thread_offset, - c_thread_sub_mtx.GetLengths()); - } - } - } -}; - template + __device__ void Run_asm(const FloatA* __restrict__ p_a_block, + const FloatB* __restrict__ p_b_block, + FloatC* __restrict__ p_c_thread, + Accumulator f_accum) const + { + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + constexpr unsigned M = a_block_mtx.NCol(); + constexpr unsigned N = b_block_mtx.NCol(); + constexpr unsigned K = a_block_mtx.NRow(); + + constexpr unsigned MPerThread = c_thread_mtx.NRow(); + constexpr unsigned NPerThread = c_thread_mtx.NCol(); + + // thread A, B for GEMM + constexpr auto a_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + constexpr auto b_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + // thread A-sub, B-sub for copy + constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + + constexpr unsigned MRepeat = MPerThread / MPerThreadSubC; + constexpr unsigned NRepeat = NPerThread / NPerThreadSubC; + +#pragma unroll + // loop over k + for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) + { +#pragma unroll + // copy A-sub to form A + for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { + threadwise_matrix_copy( + a_block_mtx, + p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + + mMyThreadOffsetA, + a_thread_mtx, + p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), + a_thread_sub_mtx.GetLengths()); + } + +#pragma unroll + 
// copy B-sub to form B
+        for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+        {
+            threadwise_matrix_copy(
+                b_block_mtx,
+                p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
+                    mMyThreadOffsetB,
+                b_thread_mtx,
+                p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
+                b_thread_sub_mtx.GetLengths());
+        }
+
+// C = A * B
+#if 1
+        threadwise_gemm(a_thread_mtx,
+                        True,
+                        p_a_thread,
+                        b_thread_mtx,
+                        False,
+                        p_b_thread,
+                        c_thread_mtx,
+                        False,
+                        p_c_thread,
+                        f_accum);
+#else
+        // inline asm
+        static_assert(c_thread_mtx.NRow() == 8 && c_thread_mtx.NCol() == 8,
+                      "asm is only for 8x8");
+
+        for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k) // A is transposed
+        {
+            const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+
+            for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+            {
+                const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0);
+
+                asm volatile("\n \
+                    v_mac_f32 %0, %8, %9 \n \
+                    v_mac_f32 %1, %8, %10 \n \
+                    v_mac_f32 %2, %8, %11 \n \
+                    v_mac_f32 %3, %8, %12 \n \
+                    v_mac_f32 %4, %8, %13 \n \
+                    v_mac_f32 %5, %8, %14 \n \
+                    v_mac_f32 %6, %8, %15 \n \
+                    v_mac_f32 %7, %8, %16 \n \
+                    "
+                             : "=v"(p_c_thread[cindex + 0]),
+                               "=v"(p_c_thread[cindex + 1]),
+                               "=v"(p_c_thread[cindex + 2]),
+                               "=v"(p_c_thread[cindex + 3]),
+                               "=v"(p_c_thread[cindex + 4]),
+                               "=v"(p_c_thread[cindex + 5]),
+                               "=v"(p_c_thread[cindex + 6]),
+                               "=v"(p_c_thread[cindex + 7])
+                             : "v"(p_a_thread[aindex]),
+                               "v"(p_b_thread[bindex + 0]),
+                               "v"(p_b_thread[bindex + 1]),
+                               "v"(p_b_thread[bindex + 2]),
+                               "v"(p_b_thread[bindex + 3]),
+                               "v"(p_b_thread[bindex + 4]),
+                               "v"(p_b_thread[bindex + 5]),
+                               "v"(p_b_thread[bindex + 6]),
+                               "v"(p_b_thread[bindex + 7]),
+                               "0"(p_c_thread[cindex + 0]),
+                               "1"(p_c_thread[cindex + 1]),
+                               "2"(p_c_thread[cindex + 2]),
+                               "3"(p_c_thread[cindex + 3]),
+                               "4"(p_c_thread[cindex + 4]),
+                               "5"(p_c_thread[cindex + 5]),
+                               "6"(p_c_thread[cindex + 6]),
+                               "7"(p_c_thread[cindex + 7]));
+            }
+        }
+#endif
+        }
+    }
+
     template <class FloatA, class FloatB, class FloatC, class Accumulator>
     __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                              FloatB* const p_b_block,
diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
index 4dac26cab8..c13eabc8f3 100644
--- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
+++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
@@ -6,7 +6,7 @@
 #include "blockwise_2d_tensor_op.hip.hpp"
 #include "threadwise_nd_tensor_op.hip.hpp"
 #include "threadwise_4d_tensor_op.hip.hpp"
-#include "blockwise_gemm.hip.hpp"
+#include "blockwise_batched_gemm.hip.hpp"
 template
-__global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
+__global__ void
+#if 0
+__launch_bounds__(256,2)
+#endif
+gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
     const Float* const __restrict__ p_in_global,
     const Float* const __restrict__ p_wei_global,
     Float* const __restrict__ p_out_global)
@@ -280,15 +284,15 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_b
         for(unsigned x = 0; x < X; ++x)
         {
             auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
-#if 1
+#if 0
             blockwise_gemm.Run
 #else
             blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
-            (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
-             p_in_block_now + y * Wi + x,
-             p_out_thread,
-             f_accum);
+                (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
+                 p_in_block_now + y * Wi + x,
+                 p_out_thread,
+                 f_accum);
         }
     }
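
For reference, the thread-to-tile mapping that GetBeginOfThreadMatrixC implements in the new
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 can be checked on the host. The
sketch below uses illustrative cluster sizes (the constants are picked for the example, not taken
from this diff) and reproduces the same batch / level-1 / level-0 decomposition, printing where
each thread's C sub-tile begins:

    #include <cstdio>

    // Illustrative values only: one batch per thread, 4x4 sub-tiles,
    // a 4x4 level-0 cluster and a 2x1 level-1 cluster (64 threads total).
    constexpr unsigned BatchPerThread = 1;
    constexpr unsigned MPerThreadSubC = 4, NPerThreadSubC = 4;
    constexpr unsigned MLevel0Cluster = 4, NLevel0Cluster = 4;
    constexpr unsigned MLevel1Cluster = 2, NLevel1Cluster = 1;

    int main()
    {
        constexpr unsigned ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
        constexpr unsigned ThreadPerLevel1Cluster =
            ThreadPerLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr unsigned MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr unsigned NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        for(unsigned thread_id = 0; thread_id < 2 * ThreadPerLevel1Cluster; ++thread_id)
        {
            // same decomposition as GetBeginOfThreadMatrixC in the V2 blockwise GEMM
            unsigned batch_work_id = thread_id / ThreadPerLevel1Cluster;
            unsigned cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

            unsigned level1_id   = cluster_id / ThreadPerLevel0Cluster;
            unsigned level1_m_id = level1_id / NLevel1Cluster;
            unsigned level1_n_id = level1_id % NLevel1Cluster;

            unsigned level0_id   = cluster_id % ThreadPerLevel0Cluster;
            unsigned level0_m_id = level0_id / NLevel0Cluster;
            unsigned level0_n_id = level0_id % NLevel0Cluster;

            printf("thread %3u -> batch %u, row %3u, col %3u\n",
                   thread_id,
                   batch_work_id * BatchPerThread,
                   level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                   level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC);
        }
    }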
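
On the inline asm: each "=v" output in the v_mac_f32 blocks is tied to a matching "0".."3" (or
"0".."7") input constraint, so the compiler keeps the accumulator in the same VGPR for read and
write, which v_mac_f32's read-modify-write destination requires. What the 16x4 block computes is
just a rank-1 update of the thread's C tile; a portable C++ equivalent (hypothetical helper name,
fixed 16x4 shape as enforced by the static_assert, not part of the diff):

    // One A value broadcast against four B values per row of C, as in the
    // 16x4 v_mac_f32 block of Run_v3 (row-major C, one k-slice).
    void mac_16x4(const float* a_col /*16*/, const float* b_row /*4*/, float* c /*16x4*/)
    {
        for(unsigned i = 0; i < 16; ++i)
            for(unsigned j = 0; j < 4; ++j)
                c[i * 4 + j] += a_col[i] * b_row[j]; // v_mac_f32 c, a, b
    }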
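
Run_RegisterDoubleBuffer, which the last hunk switches the gridwise kernel to, follows the usual
software-pipelining shape: prefetch the tile for iteration k+1 into one register buffer while the
GEMM consumes iteration k from the other, so memory latency overlaps with the v_mac_f32 work. A
generic sketch of that schedule, with hypothetical load_tile/compute_tile callbacks and K assumed
divisible by KPerLoop (not the actual member function, which works on ConstantMatrixDescriptor
tiles):

    // Two register buffers; prologue fills buffer 0, the steady-state loop
    // prefetches into the other buffer before consuming the current one.
    template <unsigned K, unsigned KPerLoop, unsigned TileSize, class Load, class Compute>
    void register_double_buffer(Load load_tile, Compute compute_tile)
    {
        float reg[2][TileSize];

        load_tile(reg[0], 0u); // prologue: fetch first tile

        for(unsigned k = 0, buf = 0; k + KPerLoop < K; k += KPerLoop, buf ^= 1)
        {
            load_tile(reg[buf ^ 1], k + KPerLoop); // prefetch next tile
            compute_tile(reg[buf]);                // consume current tile
        }

        compute_tile(reg[(K / KPerLoop - 1) & 1]); // epilogue: last tile
    }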