This commit is contained in:
Chao Liu
2019-01-24 16:15:51 -06:00
parent c39c573eb8
commit bd811e2c20
2 changed files with 15 additions and 10 deletions

View File

@@ -103,6 +103,15 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
}
#endif
// blockwise copy
// wei: format is [S,R,C,K], no conversion needed
constexpr auto blockwise_wei_copy =
blockwise_4d_tensor_copy_1<BlockSize,
Float,
decltype(wei_srck_global_desc),
decltype(wei_srck_block_desc),
decltype(wei_srck_block_desc.GetLengths())>{};
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
@@ -171,13 +180,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
#if 1
// weight: global mem to LDS,
// format is [S,R,C,K], no conversion needed
blockwise_4d_tensor_copy<BlockSize>(
wei_srck_global_desc,
p_wei_global +
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin),
wei_srck_block_desc,
p_wei_block,
wei_srck_block_desc.GetLengths());
blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex(
0, 0, c_block_data_begin, k_block_data_begin),
p_wei_block);
#endif
__syncthreads();