mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
refactor
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
#include "device_direct_convolution_1.cuh"
|
||||
#include "device_direct_convolution_2.cuh"
|
||||
//#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
|
||||
//#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
|
||||
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
|
||||
//#include "device_winograd_convolution.cuh"
|
||||
|
||||
@@ -418,9 +418,9 @@ int main()
|
||||
device_direct_convolution_2
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_kcsr
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_srck_nkhw
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_1_nchw_srck_nkhw
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_cnhw_srck_knhw
|
||||
#elif 0
|
||||
device_winograd_convolution
|
||||
|
||||
@@ -103,6 +103,15 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
}
|
||||
#endif
|
||||
|
||||
// blockwise copy
|
||||
// wei: format is [S,R,C,K], no conversion needed
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
@@ -171,13 +180,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
// format is [S,R,C,K], no conversion needed
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
wei_srck_global_desc,
|
||||
p_wei_global +
|
||||
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin),
|
||||
wei_srck_block_desc,
|
||||
p_wei_block,
|
||||
wei_srck_block_desc.GetLengths());
|
||||
blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex(
|
||||
0, 0, c_block_data_begin, k_block_data_begin),
|
||||
p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Reference in New Issue
Block a user