mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
unroll some loop, register double buffer gemm
This commit is contained in:
@@ -611,7 +611,7 @@ int main()
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
if(S == 3 && R == 3)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
|
||||
@@ -66,42 +66,12 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw_gemm_2(InDesc,
|
||||
|
||||
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// 1x1, 28x28
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
|
||||
constexpr unsigned GemmMPerThreadSubC = 16;
|
||||
constexpr unsigned GemmNPerThreadSubC = 4;
|
||||
constexpr unsigned GemmMLevel0Cluster = 4;
|
||||
constexpr unsigned GemmNLevel0Cluster = 8;
|
||||
constexpr unsigned GemmMLevel1Cluster = 1;
|
||||
constexpr unsigned GemmNLevel1Cluster = 2;
|
||||
constexpr unsigned GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 64;
|
||||
#elif 1
|
||||
// 1x1, 28x28 try
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
|
||||
constexpr unsigned BPerThread = 8;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
|
||||
|
||||
@@ -598,9 +598,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(unsigned k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
#pragma unroll
|
||||
// copy A-sub to form A
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
@@ -613,6 +615,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
// copy B-sub to form B
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
@@ -638,4 +641,148 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC, class Accumulator>
|
||||
__device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
|
||||
FloatB* const p_b_block,
|
||||
FloatC* p_c_thread,
|
||||
Accumulator f_accum) const
|
||||
{
|
||||
constexpr auto True = Constant<bool, true>{};
|
||||
constexpr auto False = Constant<bool, false>{};
|
||||
|
||||
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
|
||||
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
|
||||
const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile
|
||||
|
||||
constexpr unsigned M = a_block_mtx.NCol();
|
||||
constexpr unsigned N = b_block_mtx.NCol();
|
||||
constexpr unsigned K = a_block_mtx.NRow();
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
const auto a_thread_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_thread_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
const auto a_thread_sub_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
|
||||
Number<MPerThreadSubC>{},
|
||||
Number<MPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_thread_sub_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{},
|
||||
Number<NPerThreadSubC>{},
|
||||
Number<NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// preload A, B
|
||||
#pragma unroll
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_0 + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_0 + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
bool even_loop = true;
|
||||
|
||||
#pragma unroll
|
||||
for(unsigned k_begin = 0; k_begin + 1 < K;
|
||||
k_begin += KPerThreadLoop, even_loop = !even_loop)
|
||||
{ // loop over k
|
||||
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
|
||||
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
|
||||
|
||||
FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
|
||||
FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
|
||||
|
||||
// preload next A, B
|
||||
|
||||
#pragma unroll
|
||||
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA +
|
||||
(k_begin + 1) * a_block_mtx.RowStride() +
|
||||
m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_next + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB +
|
||||
(k_begin + 1) * b_block_mtx.RowStride() +
|
||||
n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_next + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread_now,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread_now,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread,
|
||||
f_accum);
|
||||
}
|
||||
|
||||
// last loop
|
||||
{
|
||||
even_loop = !even_loop;
|
||||
|
||||
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
|
||||
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread_now,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread_now,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -237,10 +237,15 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#else
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user