mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
simplify blockwise batched GEMM
This commit is contained in:
@@ -211,40 +211,45 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
#pragma unroll
|
||||
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
|
||||
{
|
||||
// read first batch of A, B
|
||||
// copy A-sub to form A
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
// copy B-sub to form B
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
|
||||
mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// loop over batch
|
||||
#pragma unroll
|
||||
for(index_t ib = 0; ib + 1 < BatchPerThread; ++ib)
|
||||
for(index_t ib = 0; ib < BatchPerThread; ++ib)
|
||||
{
|
||||
// do current batch of gemm
|
||||
// read next batch of a, b
|
||||
if(BlockMatrixStrideA != 0 or ib == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
p_a_block +
|
||||
a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
ib * BlockMatrixStrideA + mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
}
|
||||
|
||||
if(BlockMatrixStrideB != 0 or ib == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
p_b_block +
|
||||
b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
|
||||
ib * BlockMatrixStrideB + mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
}
|
||||
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread,
|
||||
@@ -255,52 +260,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
False,
|
||||
p_c_thread + ib * ThreadMatrixStrideC);
|
||||
|
||||
// read next batch of a, b
|
||||
if(BlockMatrixStrideA != 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
p_a_block +
|
||||
a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
}
|
||||
|
||||
if(BlockMatrixStrideB != 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
p_b_block +
|
||||
b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
|
||||
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// do last batch of gemm
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user