another version of direct conv

This commit is contained in:
Chao Liu
2018-12-18 03:22:12 -06:00
parent 20423a3583
commit 39775d484c
12 changed files with 596 additions and 157 deletions

View File

@@ -130,6 +130,7 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
blockwise_4d_tensor_copy<TFloat,
decltype(in_block_src_desc),
decltype(in_block_desc),
decltype(in_block_desc),
BlockSize>(in_block_src_desc,
p_in_global +
in_global_desc.Get1dIndex(n_block_work_begin,
@@ -137,17 +138,20 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
hi_block_work_begin,
wi_block_work_begin),
in_block_desc,
p_in_block);
p_in_block,
in_block_desc);
// copy weight tensor to LDS
blockwise_4d_tensor_copy<TFloat,
decltype(wei_block_src_desc),
decltype(wei_block_desc),
decltype(wei_block_desc),
BlockSize>(
wei_block_src_desc,
p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
wei_block_desc,
p_wei_block);
p_wei_block,
wei_block_desc);
__syncthreads();
@@ -171,11 +175,13 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
blockwise_4d_tensor_copy<TFloat,
decltype(out_block_desc),
decltype(out_block_src_desc),
decltype(out_block_desc),
BlockSize>(
out_block_desc,
p_out_block,
out_block_src_desc,
p_out_global +
out_global_desc.Get1dIndex(
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin),
out_block_desc);
}