diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp
index 221a7153a2..f80c49a029 100644
--- a/src/include/blockwise_gemm.hip.hpp
+++ b/src/include/blockwise_gemm.hip.hpp
@@ -435,11 +435,12 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 #pragma unroll
         for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
         {
-// read first batch of A, B
-// copy A-sub to form A
-#pragma unroll
+            // read first batch of A, B
+            // copy A-sub to form A
+            //#pragma unroll
             for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
+#if 0
                 threadwise_matrix_copy(
                     a_block_mtx,
                     p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
@@ -447,12 +448,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     a_thread_mtx,
                     p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                     a_thread_sub_mtx.GetLengths());
+#else
+                for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j)
+                    {
+                        p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
+                            p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
+                                                             m_repeat * MPerLevel1Cluster + j) +
+                                      mMyThreadOffsetA];
+                    }
+                }
+#endif
             }

-// copy B-sub to form B
-#pragma unroll
+            // copy B-sub to form B
+            //#pragma unroll
             for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
             {
+#if 0
                 threadwise_matrix_copy(
                     b_block_mtx,
                     p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
@@ -460,13 +474,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     b_thread_mtx,
                     p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                     b_thread_sub_mtx.GetLengths());
+#else
+                for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j)
+                    {
+                        p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
+                            p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
+                                                             n_repeat * NPerLevel1Cluster + j) +
+                                      mMyThreadOffsetB];
+                    }
+                }
+#endif
             }

-// loop over batch
-#pragma unroll
+            // loop over batch
+            //#pragma unroll
             for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
             {
-                // do current batch of gemm
+// do current batch of gemm
+#if 0
                 threadwise_gemm(a_thread_mtx,
                                 True,
                                 p_a_thread,
@@ -477,13 +504,32 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                                 False,
                                 p_c_thread + ib * ThreadMatrixStrideC,
                                 f_accum);
+#else
+                for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+                {
+                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                    {
+                        for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                        {
+                            const unsigned aindex =
+                                a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                            const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
+                            const unsigned cindex =
+                                c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
+
+                            f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
+                        }
+                    }
+                }
+#endif

                 // read next batch of a, b
                 if(BlockMatrixStrideA != 0)
                 {
-#pragma unroll
+                    //#pragma unroll
                     for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                     {
+#if 0
                         threadwise_matrix_copy(
                             a_block_mtx,
                             p_a_block +
@@ -492,14 +538,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                             a_thread_mtx,
                             p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                             a_thread_sub_mtx.GetLengths());
+#else
+                        for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i)
+                        {
+                            for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j)
+                            {
+                                p_a_thread[a_thread_mtx.Get1dIndex(i,
+                                                                   m_repeat * MPerThreadSubC + j)] =
+                                    p_a_block[a_block_mtx.Get1dIndex(
+                                                  k_begin + i, m_repeat * MPerLevel1Cluster + j) +
+                                              (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
+                            }
+                        }
+#endif
                     }
                 }

                 if(BlockMatrixStrideB != 0)
                 {
-#pragma unroll
+                    //#pragma unroll
                     for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                     {
+#if 0
                         threadwise_matrix_copy(
                             b_block_mtx,
                             p_b_block +
@@ -508,11 +568,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                             b_thread_mtx,
                             p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                             b_thread_sub_mtx.GetLengths());
+#else
+                        for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i)
+                        {
+                            for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j)
+                            {
+                                p_b_thread[b_thread_mtx.Get1dIndex(i,
+                                                                   n_repeat * NPerThreadSubC + j)] =
+                                    p_b_block[b_block_mtx.Get1dIndex(
+                                                  k_begin + i, n_repeat * NPerLevel1Cluster + j) +
+                                              (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
+                            }
+                        }
+#endif
                     }
                 }
             }

-            // do last batch of gemm
+// do last batch of gemm
+#if 0
             threadwise_gemm(a_thread_mtx,
                             True,
                             p_a_thread,
@@ -523,6 +597,23 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                             False,
                             p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
                             f_accum);
+#else
+            for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
+            {
+                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
+                    {
+                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
+                        const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
+                        const unsigned cindex =
+                            c_thread_mtx.Get1dIndex(i, j) + (BatchPerThread - 1) * ThreadMatrixStrideC;
+
+                        f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
+                    }
+                }
+            }
+#endif
         }
     }
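For readers following the change: the `#if 0` / `#else` blocks keep the original `threadwise_matrix_copy` / `threadwise_gemm` helper calls in the source while switching the build to hand-written loops, so the two paths can be compared (e.g. for unrolling or codegen behavior) without deleting either one. The inlined GEMM accumulates `C(i, j) += A(k, i) * B(k, j)` with A read transposed, which is what the `True` / `False` transpose flags on the replaced `threadwise_gemm` call encode. Below is a minimal standalone sketch of that loop nest in plain C++; the tile sizes and the free-standing `Get1dIndex` helper are hypothetical stand-ins for the matrix descriptors in the real header, and plain `+=` stands in for `f_accum`.

```cpp
// Standalone sketch of the inlined threadwise GEMM above (hypothetical sizes).
// A is stored K x M and read transposed, B is K x N, C is M x N in registers.
#include <cstdio>

constexpr unsigned K = 4; // k-loop length per pass (KPerThreadLoop in the header)
constexpr unsigned M = 2; // rows of the per-thread C tile
constexpr unsigned N = 3; // columns of the per-thread C tile

// Row-major 1-D offset; stand-in for the descriptors' Get1dIndex().
constexpr unsigned Get1dIndex(unsigned row, unsigned col, unsigned ncol)
{
    return row * ncol + col;
}

int main()
{
    float a[K * M], b[K * N], c[M * N] = {0.0f};

    for(unsigned i = 0; i < K * M; ++i)
        a[i] = 1.0f + static_cast<float>(i); // arbitrary test data
    for(unsigned i = 0; i < K * N; ++i)
        b[i] = 0.5f * static_cast<float>(i);

    // Same k-outermost accumulation as the #else branches above:
    //   C(i, j) += A(k, i) * B(k, j), A indexed (k, i) because it is transposed.
    // In the header the body is
    //   f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
    for(unsigned k = 0; k < K; ++k)
        for(unsigned i = 0; i < M; ++i)
            for(unsigned j = 0; j < N; ++j)
                c[Get1dIndex(i, j, N)] +=
                    a[Get1dIndex(k, i, M)] * b[Get1dIndex(k, j, N)];

    for(unsigned i = 0; i < M; ++i)
    {
        for(unsigned j = 0; j < N; ++j)
            printf("%6.1f ", c[Get1dIndex(i, j, N)]);
        printf("\n");
    }
    return 0;
}
```

Keeping k outermost means each A element fetched for a given k is reused across a full row of the C tile and each B element across a full column, which is the point of this register-level outer-product formulation.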