Another version of direct convolution

This commit is contained in:
Chao Liu
2018-12-18 03:22:12 -06:00
parent 20423a3583
commit 39775d484c
12 changed files with 596 additions and 157 deletions

View File

@@ -46,15 +46,11 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;
constexpr auto in_block_global_desc = make_ConstantTensorDescriptor(
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
constexpr auto in_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
constexpr auto wei_block_global_desc = make_ConstantTensorDescriptor(
Sequence<KPerBlock, CPerBlock, S, R>{}, wei_global_desc.GetStrides());
constexpr auto in_block_desc = make_ConstantTensorDescriptor(in_block_global_desc.GetLengths());
constexpr auto wei_block_desc =
make_ConstantTensorDescriptor(wei_block_global_desc.GetLengths());
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, S, R>{});
// shared mem
constexpr unsigned in_block_size = in_block_desc.GetElementSpace();
@@ -67,30 +63,19 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;
constexpr auto in_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{}, in_block_desc.GetStrides());
constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, CPerThread, S, R>{}, wei_block_desc.GetStrides());
constexpr auto in_thread_desc =
make_ConstantTensorDescriptor(in_thread_block_desc.GetLengths());
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{});
constexpr auto wei_thread_desc =
make_ConstantTensorDescriptor(wei_thread_block_desc.GetLengths());
make_ConstantTensorDescriptor(Sequence<KPerThread, CPerThread, S, R>{});
constexpr auto out_thread_desc =
get_output_4d_tensor_descriptor(in_thread_desc, wei_thread_desc);
constexpr auto out_thread_global_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides());
// register
constexpr unsigned in_thread_size = in_thread_desc.GetElementSpace();
constexpr unsigned wei_thread_size = wei_thread_desc.GetElementSpace();
constexpr unsigned out_thread_size = out_thread_desc.GetElementSpace();
TFloat p_in_thread[in_thread_size];
TFloat p_wei_thread[wei_thread_size];
TFloat p_out_thread[out_thread_size];
TFloat p_in_thread[in_thread_desc.GetElementSpace()];
TFloat p_wei_thread[wei_thread_desc.GetElementSpace()];
TFloat p_out_thread[out_thread_desc.GetElementSpace()];
// divide block work
constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
@@ -169,54 +154,60 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
{
// copy input tensor to LDS
blockwise_4d_tensor_copy<TFloat,
decltype(in_block_global_desc),
decltype(in_global_desc),
decltype(in_block_desc),
BlockSize>(in_block_global_desc,
decltype(in_block_desc),
BlockSize>(in_global_desc,
p_in_global +
in_global_desc.Get1dIndex(n_block_data_begin,
c_block_data_begin,
hi_block_data_begin,
wi_block_data_begin),
in_block_desc,
p_in_block);
p_in_block,
in_block_desc);
// copy weight tensor to LDS
blockwise_4d_tensor_copy<TFloat,
decltype(wei_block_global_desc),
decltype(wei_global_desc),
decltype(wei_block_desc),
decltype(wei_block_desc),
BlockSize>(
wei_block_global_desc,
wei_global_desc,
p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
wei_block_desc,
p_wei_block);
p_wei_block,
wei_block_desc);
__syncthreads();
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// copy input tensor into register
threadwise_4d_tensor_copy(in_thread_block_desc,
threadwise_4d_tensor_copy(in_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data,
hi_thread_data_begin,
wi_thread_data_begin),
in_thread_desc,
p_in_thread);
p_in_thread,
in_thread_desc);
// copy weight tensor into register
threadwise_4d_tensor_copy(
wei_thread_block_desc,
wei_block_desc,
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
wei_thread_desc,
p_wei_thread);
p_wei_thread,
wei_thread_desc);
// threadwise convolution
threadwise_direct_convolution(in_thread_desc,
p_in_thread,
wei_thread_desc,
p_wei_thread,
out_thread_desc,
p_out_thread);
threadwise_direct_convolution_1(in_thread_desc,
p_in_thread,
wei_thread_desc,
p_wei_thread,
out_thread_desc,
p_out_thread);
}
}
@@ -224,9 +215,10 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
threadwise_4d_tensor_copy(
out_thread_desc,
p_out_thread,
out_thread_global_desc,
out_global_desc,
p_out_global + out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin));
wo_block_data_begin + wo_thread_data_begin),
out_thread_desc);
}