This commit is contained in:
Chao Liu
2018-11-15 23:22:06 -06:00
parent 99d05ba77f
commit adf4b173b3
3 changed files with 176 additions and 172 deletions

View File

@@ -12,11 +12,11 @@ template <class TFloat,
unsigned OutTileSizeW,
unsigned BlockSize>
__device__ void blockwise_convolution(InDesc,
TFloat* const __restrict__ p_in,
TFloat* const __restrict__ p_in_lds,
WeiDesc,
TFloat* const __restrict__ p_wei,
TFloat* const __restrict__ p_wei_lds,
OutDesc,
TFloat* __restrict__ p_out)
TFloat* __restrict__ p_out_lds)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
@@ -97,8 +97,8 @@ __device__ void blockwise_convolution(InDesc,
decltype(in_thread_dst_desc),
decltype(f_copy)>(
in_thread_src_desc,
p_in + in_desc.Get1dIndex(
n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
p_in_lds + in_desc.Get1dIndex(
n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
in_thread_dst_desc,
p_in_thread,
f_copy);
@@ -112,7 +112,7 @@ __device__ void blockwise_convolution(InDesc,
decltype(wei_thread_dst_desc),
decltype(f_copy)>(
wei_thread_src_desc,
p_wei + wei_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
p_wei_lds + wei_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
wei_thread_dst_desc,
p_wei_thread,
f_copy);
@@ -123,10 +123,10 @@ __device__ void blockwise_convolution(InDesc,
decltype(out_thread_dst_desc),
decltype(f_copy)>(
out_thread_src_desc,
p_out + out_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin),
p_out_lds + out_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin),
out_thread_dst_desc,
p_out_thread,
f_copy);
@@ -150,10 +150,10 @@ __device__ void blockwise_convolution(InDesc,
out_thread_dst_desc,
p_out_thread,
out_thread_src_desc,
p_out + out_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin),
p_out_lds + out_desc.Get1dIndex(n_thread_work_begin,
k_thread_work_begin,
ho_thread_work_begin,
wo_thread_work_begin),
f_copy);
}
}
@@ -170,18 +170,18 @@ template <class TFloat,
unsigned CPerBlock,
unsigned YPerBlock,
unsigned XPerBlock,
unsigned NBlockCopyLen0,
unsigned NBlockCopyLen1,
unsigned NBlockCopyLen2,
unsigned NBlockCopyLen3,
unsigned NBlockOpLen0,
unsigned NBlockOpLen1,
unsigned NBlockOpLen2,
unsigned NBlockOpLen3,
unsigned BlockSize,
unsigned GridSize>
__global__ void gridwise_convolution(InDesc,
TFloat* const __restrict__ p_in,
TFloat* const __restrict__ p_in_glb,
WeiDesc,
TFloat* const __restrict__ p_wei,
TFloat* const __restrict__ p_wei_glb,
OutDesc,
TFloat* __restrict__ p_out)
TFloat* __restrict__ p_out_glb)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
@@ -222,13 +222,13 @@ __global__ void gridwise_convolution(InDesc,
constexpr auto out_block_lds_desc =
make_ConstantTensorDescriptor(out_block_glb_desc.GetLengths());
constexpr unsigned in_block_size = in_block_lds_desc.GetElementSize();
constexpr unsigned wei_block_size = wei_block_lds_desc.GetElementSize();
constexpr unsigned out_block_size = out_block_lds_desc.GetElementSize();
constexpr unsigned in_block_size = in_block_lds_desc.GetElementSpace();
constexpr unsigned wei_block_size = wei_block_lds_desc.GetElementSpace();
constexpr unsigned out_block_size = out_block_lds_desc.GetElementSpace();
__shared__ TFloat p_in_block[in_block_size];
__shared__ TFloat p_wei_block[wei_block_size];
__shared__ TFloat p_out_block[out_block_size];
__shared__ TFloat p_in_block_lds[in_block_size];
__shared__ TFloat p_wei_block_lds[wei_block_size];
__shared__ TFloat p_out_block_lds[out_block_size];
const unsigned block_id = blockIdx.x;
@@ -286,12 +286,12 @@ __global__ void gridwise_convolution(InDesc,
// set output tensor in LDS to 0
blockwise_4d_tensor_op_unary<TFloat,
decltype(out_block_lds_desc),
NBlockCopyLen0,
NBlockCopyLen1,
NBlockCopyLen2,
NBlockCopyLen3,
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_set0),
BlockSize>(out_block_lds_desc, p_out_block, f_set0);
BlockSize>(out_block_lds_desc, p_out_block_lds, f_set0);
for(unsigned c_block_work_begin = 0; c_block_work_begin < in_desc.GetLength(I1);
c_block_work_begin += CPerBlock)
@@ -301,35 +301,35 @@ __global__ void gridwise_convolution(InDesc,
blockwise_4d_tensor_op_binary<TFloat,
decltype(in_block_glb_desc),
decltype(in_block_lds_desc),
NBlockCopyLen0,
NBlockCopyLen1,
NBlockCopyLen2,
NBlockCopyLen3,
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_copy),
BlockSize>(
in_block_glb_desc,
p_in + in_block_glb_desc.Get1dIndex(n_block_work_begin,
c_block_work_begin,
hi_block_work_begin,
wi_block_work_begin),
p_in_glb + in_block_glb_desc.Get1dIndex(n_block_work_begin,
c_block_work_begin,
hi_block_work_begin,
wi_block_work_begin),
in_block_lds_desc,
p_in_block,
p_in_block_lds,
f_copy);
// copy weight tensor to LDS
blockwise_4d_tensor_op_binary<TFloat,
decltype(wei_block_glb_desc),
decltype(wei_block_lds_desc),
NBlockCopyLen0,
NBlockCopyLen1,
NBlockCopyLen2,
NBlockCopyLen3,
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_copy),
BlockSize>(
wei_block_glb_desc,
p_wei + wei_block_glb_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
p_wei_glb + wei_block_glb_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
wei_block_lds_desc,
p_wei_block,
p_wei_block_lds,
f_copy);
#if 1
@@ -344,11 +344,11 @@ __global__ void gridwise_convolution(InDesc,
OutTileSizeH,
OutTileSizeW,
BlockSize>(in_block_lds_desc,
p_in_block,
p_in_block_lds,
wei_block_lds_desc,
p_wei_block,
p_wei_block_lds,
out_block_lds_desc,
p_out_block);
p_out_block_lds);
#if 1
__syncthreads();
@@ -359,16 +359,16 @@ __global__ void gridwise_convolution(InDesc,
blockwise_4d_tensor_op_binary<TFloat,
decltype(out_block_lds_desc),
decltype(out_block_glb_desc),
NBlockCopyLen0,
NBlockCopyLen1,
NBlockCopyLen2,
NBlockCopyLen3,
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
decltype(f_copy),
BlockSize>(
out_block_lds_desc,
p_out_block,
p_out_block_lds,
out_block_glb_desc,
p_out +
p_out_glb +
out_block_glb_desc.Get1dIndex(
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin),
f_copy);