diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp
index 36e80641f2..3ef2e036d4 100644
--- a/src/include/blockwise_gemm.hip.hpp
+++ b/src/include/blockwise_gemm.hip.hpp
@@ -526,7 +526,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         }
     }
 
-    // this version put copy and compute in same place, experimenting with compiler behaviour
     template <class FloatA, class FloatB, class FloatC, class Accumulator>
     __device__ void Run_v2(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
@@ -687,6 +686,231 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         }
     }
 
+    template <class FloatA, class FloatB, class FloatC, class Accumulator>
+    __device__ void Run_v3(const FloatA* __restrict__ p_a_block,
+                           const FloatB* __restrict__ p_b_block,
+                           FloatC* __restrict__ p_c_thread,
+                           Accumulator f_accum) const
+    {
+        constexpr auto True  = integral_constant<bool, true>{};
+        constexpr auto False = integral_constant<bool, false>{};
+
+        constexpr auto a_block_mtx  = BlockMatrixA{};
+        constexpr auto b_block_mtx  = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
+
+        constexpr unsigned MPerThread = c_thread_mtx.NRow();
+        constexpr unsigned NPerThread = c_thread_mtx.NCol();
+
+        // thread A, B for GEMM
+        // A is transposed, B is not
+        constexpr auto a_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+
+        // thread A-sub, B-sub for copy
+        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
+        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
+        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
+
+        // loop over k
+        //#pragma unroll
+        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
+        {
+            // read first batch of A, B
+            // copy A-sub to form A
+            //#pragma unroll
+            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+            {
+                for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
+                    {
+                        p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
+                            p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
+                                                             m_repeat * MPerLevel1Cluster + j) +
+                                      mMyThreadOffsetA];
+                    }
+                }
+            }
+
+            // copy B-sub to form B
+            //#pragma unroll
+            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+            {
+                for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
+                    {
+                        p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
+                            p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
+                                                             n_repeat * NPerLevel1Cluster + j) +
+                                      mMyThreadOffsetB];
+                    }
+                }
+            }
+
+            // loop over batch
+            //#pragma unroll
+            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
+            {
+                // do current batch of gemm
+                for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+                {
+#if 0
+                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                    {
+                        for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                        {
+                            const unsigned aindex =
+                                a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                            const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
+                            const unsigned cindex =
+                                c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
+
+                            f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
+                        }
+                    }
+#elif 1
+                    static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
+                                  "asm is only for 16x4");
+
+                    const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                    {
+                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                        const unsigned cindex =
+                            c_thread_mtx.Get1dIndex(i, 0) + ib * ThreadMatrixStrideC;
+
+                        asm volatile("\n \
+                            v_mac_f32 %0, %4, %5 \n \
+                            v_mac_f32 %1, %4, %6 \n \
+                            v_mac_f32 %2, %4, %7 \n \
+                            v_mac_f32 %3, %4, %8 \n \
+                            "
+                                     : "=v"(p_c_thread[cindex + 0]),
+                                       "=v"(p_c_thread[cindex + 1]),
+                                       "=v"(p_c_thread[cindex + 2]),
+                                       "=v"(p_c_thread[cindex + 3])
+                                     : "v"(p_a_thread[aindex]),
+                                       "v"(p_b_thread[bindex + 0]),
+                                       "v"(p_b_thread[bindex + 1]),
+                                       "v"(p_b_thread[bindex + 2]),
+                                       "v"(p_b_thread[bindex + 3]),
+                                       "0"(p_c_thread[cindex + 0]),
+                                       "1"(p_c_thread[cindex + 1]),
+                                       "2"(p_c_thread[cindex + 2]),
+                                       "3"(p_c_thread[cindex + 3]));
+                    }
+#endif
+                }
+
+                // read next batch of A, B
+                if(BlockMatrixStrideA != 0)
+                {
+                    //#pragma unroll
+                    for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
+                    {
+                        for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
+                        {
+                            for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
+                            {
+                                p_a_thread[a_thread_mtx.Get1dIndex(
+                                    i, m_repeat * MPerThreadSubC + j)] =
+                                    p_a_block[a_block_mtx.Get1dIndex(
+                                                  k_begin + i, m_repeat * MPerLevel1Cluster + j) +
+                                              (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
+                            }
+                        }
+                    }
+                }
+
+                if(BlockMatrixStrideB != 0)
+                {
+                    //#pragma unroll
+                    for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
+                    {
+                        for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
+                        {
+                            for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
+                            {
+                                p_b_thread[b_thread_mtx.Get1dIndex(
+                                    i, n_repeat * NPerThreadSubC + j)] =
+                                    p_b_block[b_block_mtx.Get1dIndex(
+                                                  k_begin + i, n_repeat * NPerLevel1Cluster + j) +
+                                              (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
+                            }
+                        }
+                    }
+                }
+            }
+
+            // do last batch of gemm
+            for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+            {
+#if 0
+                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                    {
+                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                        const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
+                        const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
+                                                (BatchPerThread - 1) * ThreadMatrixStrideC;
+
+                        f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
+                    }
+                }
+#elif 1
+                static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
+                              "asm is only for 16x4");
+
+                const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
+                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                {
+                    const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                    const unsigned cindex =
+                        c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
+
+                    asm volatile("\n \
+                        v_mac_f32 %0, %4, %5 \n \
+                        v_mac_f32 %1, %4, %6 \n \
+                        v_mac_f32 %2, %4, %7 \n \
+                        v_mac_f32 %3, %4, %8 \n \
+                        "
+                                 : "=v"(p_c_thread[cindex + 0]),
+                                   "=v"(p_c_thread[cindex + 1]),
+                                   "=v"(p_c_thread[cindex + 2]),
+                                   "=v"(p_c_thread[cindex + 3])
+                                 : "v"(p_a_thread[aindex]),
+                                   "v"(p_b_thread[bindex + 0]),
+                                   "v"(p_b_thread[bindex + 1]),
+                                   "v"(p_b_thread[bindex + 2]),
+                                   "v"(p_b_thread[bindex + 3]),
+                                   "0"(p_c_thread[cindex + 0]),
+                                   "1"(p_c_thread[cindex + 1]),
+                                   "2"(p_c_thread[cindex + 2]),
+                                   "3"(p_c_thread[cindex + 3]));
+                }
+#endif
+            }
+        }
+    }
+
     template <class FloatC>
     __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                     FloatC* __restrict__ p_c_block) const
diff --git a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
index 292a2f16eb..4dac26cab8 100644
--- a/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
+++ b/src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
@@ -209,15 +209,17 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
             {
                 for(unsigned x = 0; x < X; ++x)
                 {
-#if 1
+#if 0
                     blockwise_batch_gemm.Run
 #elif 0
                     blockwise_batch_gemm.Run_v2
+#elif 1
+                    blockwise_batch_gemm.Run_v3
 #endif
-                    (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
-                     p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
-                     p_out_thread,
-                     [](auto& acc, const auto&& v) { acc += v; });
+                        (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
+                         p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
+                         p_out_thread,
+                         [](auto& acc, const auto&& v) { acc += v; });
                 }
            }
        }
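
Note (not part of the patch): the inline v_mac_f32 block in Run_v3 performs, for a fixed k and row i, the same 1x4 fused multiply-accumulate as the #if 0 reference path, i.e. c[i][0..3] += a[k][i] * b[k][0..3] over a 16x4 per-thread C tile. A minimal standalone C++ sketch of that accumulation is shown below; the plain-array names and the explicit KPerThreadLoop parameter are illustrative assumptions standing in for the descriptor-indexed thread buffers in the patch.

// Illustrative sketch only: the per-thread 16x4 accumulation that the
// v_mac_f32 asm expresses, written as plain multiply-accumulates.
// a_thread is K x 16 (A is transposed, so rows are k), b_thread is K x 4,
// c_thread is the 16x4 accumulator tile.
void thread_gemm_16x4(const float a_thread[][16],
                      const float b_thread[][4],
                      float c_thread[16][4],
                      unsigned KPerThreadLoop)
{
    for(unsigned k = 0; k < KPerThreadLoop; ++k)   // reduction over k
        for(unsigned i = 0; i < 16; ++i)           // C rows: a[k][i] broadcast across the row
            for(unsigned j = 0; j < 4; ++j)        // C columns: one v_mac_f32 each in the asm
                c_thread[i][j] += a_thread[k][i] * b_thread[k][j];
}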