mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
refactor
This commit is contained in:
@@ -401,10 +401,10 @@ int main()
|
||||
#elif 0
|
||||
device_direct_convolution_2(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_nchw_kcsr(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_nchw_srck(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 0
|
||||
|
||||
@@ -58,22 +58,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
constexpr unsigned WBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
|
||||
|
||||
unsigned itmp = get_block_1d_id();
|
||||
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
|
||||
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
|
||||
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
|
||||
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
|
||||
const unsigned h_block_work_id = itmp / WBlockWork;
|
||||
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
|
||||
|
||||
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock;
|
||||
|
||||
const unsigned hi_block_data_begin = ho_block_data_begin;
|
||||
const unsigned wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// tensor view of un-reordered blockwise input and weight (imaginary)
|
||||
constexpr auto in_nchw_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
|
||||
@@ -106,6 +90,23 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
}
|
||||
#endif
|
||||
|
||||
// my block work
|
||||
unsigned itmp = get_block_1d_id();
|
||||
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
|
||||
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
|
||||
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
|
||||
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
|
||||
const unsigned h_block_work_id = itmp / WBlockWork;
|
||||
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
|
||||
|
||||
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock;
|
||||
|
||||
const unsigned hi_block_data_begin = ho_block_data_begin;
|
||||
const unsigned wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
@@ -156,6 +157,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
// convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
@@ -168,7 +170,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths(),
|
||||
reorder_chwn_from_nchw);
|
||||
#else
|
||||
// input: global mem to LDS,
|
||||
// no format conversion, this is wrong, for performance study only!
|
||||
blockwise_4d_tensor_copy<BlockSize>(in_nchw_global_desc,
|
||||
p_in_global +
|
||||
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
in_nchw_block_desc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths());
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
// convert [K,C,S,R] to [S,R,C,K]
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
@@ -179,6 +195,17 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
p_wei_block,
|
||||
wei_kcsr_block_desc.GetLengths(),
|
||||
reorder_srck_from_kcsr);
|
||||
#else
|
||||
// weight: global mem to LDS,
|
||||
// no format conversion, this is wrong, for performance study only!
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
wei_kcsr_global_desc,
|
||||
p_wei_global +
|
||||
wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
wei_kcsr_block_desc,
|
||||
p_wei_block,
|
||||
wei_kcsr_block_desc.GetLengths());
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -204,6 +231,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
|
||||
|
||||
#if 1
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
|
||||
constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{};
|
||||
@@ -218,4 +246,21 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_nkhw_from_hkwn);
|
||||
#else
|
||||
// output: register to global mem,
|
||||
// no format conversion, assume register is in [N,K,Ho,Wo], this is wrong, for performance
|
||||
// study only!
|
||||
constexpr auto out_nkhw_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerThread, KPerThread, HoPerThread, WoPerThread>{});
|
||||
|
||||
threadwise_4d_tensor_copy(
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths());
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user