diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp index 4959963f2b..8b8e4d9566 100644 --- a/src/include/blockwise_batched_gemm.hip.hpp +++ b/src/include/blockwise_batched_gemm.hip.hpp @@ -211,40 +211,45 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 #pragma unroll for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) { -// read first batch of A, B -// copy A-sub to form A -#pragma unroll - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - threadwise_matrix_copy( - a_block_mtx, - p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + - mMyThreadOffsetA, - a_thread_mtx, - p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths(), - Number{}); - } - -// copy B-sub to form B -#pragma unroll - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - threadwise_matrix_copy( - b_block_mtx, - p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + - mMyThreadOffsetB, - b_thread_mtx, - p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths(), - Number{}); - } - // loop over batch #pragma unroll - for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib) + for(index_t ib = 0; ib < BatchPerThread; ++ib) { - // do current batch of gemm + // read next batch of a, b + if(BlockMatrixStrideA != 0 or ib == 0) + { +#pragma unroll + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { + threadwise_matrix_copy( + a_block_mtx, + p_a_block + + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + + ib * BlockMatrixStrideA + mMyThreadOffsetA, + a_thread_mtx, + p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), + a_thread_sub_mtx.GetLengths(), + Number{}); + } + } + + if(BlockMatrixStrideB != 0 or ib == 0) + { +#pragma unroll + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { + threadwise_matrix_copy( + b_block_mtx, + p_b_block + + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + + ib * BlockMatrixStrideB + mMyThreadOffsetB, + b_thread_mtx, + p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), + b_thread_sub_mtx.GetLengths(), + Number{}); + } + } + threadwise_gemm(a_thread_mtx, True, p_a_thread, @@ -255,52 +260,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 False, p_c_thread + ib * ThreadMatrixStrideC); - // read next batch of a, b - if(BlockMatrixStrideA != 0) - { -#pragma unroll - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - threadwise_matrix_copy( - a_block_mtx, - p_a_block + - a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + - (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA, - a_thread_mtx, - p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths(), - Number{}); - } - } - - if(BlockMatrixStrideB != 0) - { -#pragma unroll - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - threadwise_matrix_copy( - b_block_mtx, - p_b_block + - b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) + - (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB, - b_thread_mtx, - p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths(), - Number{}); - } - } } - - // do last batch of gemm - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC); } }