refactor direct

This commit is contained in:
Chao Liu
2018-11-25 01:10:11 -06:00
parent 8732ea04fb
commit 24d2f034fa
14 changed files with 253 additions and 1291 deletions

View File

@@ -27,10 +27,10 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
OutGlobalDesc,
TFloat* __restrict__ p_out_global)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_global_desc = InGlobalDesc{};
constexpr auto wei_global_desc = WeiGlobalDesc{};
@@ -120,62 +120,38 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
}
#endif
auto f_set0 = [](TFloat& v) { v = TFloat(0); };
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
auto f_accu = [](const TFloat& src, TFloat& dst) { dst += src; };
// set output tensor in LDS to 0
blockwise_4d_tensor_op_unary<TFloat,
decltype(out_block_desc),
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_set0),
BlockSize>(out_block_desc, p_out_block, f_set0);
blockwise_4d_tensor_set_zero<TFloat, decltype(out_block_desc), BlockSize>(out_block_desc,
p_out_block);
for(unsigned c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
c_block_work_begin += CPerBlock)
c_block_work_begin += CPerBlock, __syncthreads())
{
// copy input tensor to LDS
blockwise_4d_tensor_op_binary<TFloat,
decltype(in_block_src_desc),
decltype(in_block_desc),
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_copy),
BlockSize>(in_block_src_desc,
p_in_global +
in_global_desc.Get1dIndex(n_block_work_begin,
c_block_work_begin,
hi_block_work_begin,
wi_block_work_begin),
in_block_desc,
p_in_block,
f_copy);
blockwise_4d_tensor_copy<TFloat,
decltype(in_block_src_desc),
decltype(in_block_desc),
BlockSize>(in_block_src_desc,
p_in_global +
in_global_desc.Get1dIndex(n_block_work_begin,
c_block_work_begin,
hi_block_work_begin,
wi_block_work_begin),
in_block_desc,
p_in_block);
// copy weight tensor to LDS
blockwise_4d_tensor_op_binary<TFloat,
decltype(wei_block_src_desc),
decltype(wei_block_desc),
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_copy),
BlockSize>(
blockwise_4d_tensor_copy<TFloat,
decltype(wei_block_src_desc),
decltype(wei_block_desc),
BlockSize>(
wei_block_src_desc,
p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
wei_block_desc,
p_wei_block,
f_copy);
p_wei_block);
#if 1
__syncthreads();
#endif
// blockwise convolution
blockwise_convolution<TFloat,
@@ -186,27 +162,17 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
OutTileSizeW,
BlockSize>(
in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block);
#if 1
__syncthreads();
#endif
}
// copy output tensor from LDS to device mem
blockwise_4d_tensor_op_binary<TFloat,
decltype(out_block_desc),
decltype(out_block_src_desc),
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_copy),
BlockSize>(
blockwise_4d_tensor_copy<TFloat,
decltype(out_block_desc),
decltype(out_block_src_desc),
BlockSize>(
out_block_desc,
p_out_block,
out_block_src_desc,
p_out_global +
out_global_desc.Get1dIndex(
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin),
f_copy);
}
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
}