mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
refactor
This commit is contained in:
@@ -103,6 +103,15 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
}
|
||||
#endif
|
||||
|
||||
// blockwise copy
|
||||
// wei: format is [S,R,C,K], no conversion needed
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
@@ -171,13 +180,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
// format is [S,R,C,K], no conversion needed
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
wei_srck_global_desc,
|
||||
p_wei_global +
|
||||
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin),
|
||||
wei_srck_block_desc,
|
||||
p_wei_block,
|
||||
wei_srck_block_desc.GetLengths());
|
||||
blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex(
|
||||
0, 0, c_block_data_begin, k_block_data_begin),
|
||||
p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Reference in New Issue
Block a user