This commit is contained in:
Chao Liu
2018-11-28 16:20:01 -06:00
parent fee92fb636
commit 1eafc9c1fb
4 changed files with 27 additions and 22 deletions

View File

@@ -41,8 +41,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;
constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;
constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
constexpr unsigned WiPerBlock = WoPerBlock + R - 1;
constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;
@@ -102,11 +102,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
constexpr auto wei_transform_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, InTileSizeH, InTileSizeW>{});
constexpr unsigned in_transform_block_size = in_transform_block_desc.GetElementSpace();
constexpr unsigned wei_transform_block_size = wei_transform_block_desc.GetElementSpace();
__shared__ TFloat p_in_transform_block[in_transform_block_size];
__shared__ TFloat p_wei_transform_block[wei_transform_block_size];
__shared__ TFloat p_in_transform_block[in_transform_block_desc.GetElementSpace()];
__shared__ TFloat p_wei_transform_block[wei_transform_block_desc.GetElementSpace()];
// thread data
constexpr auto in_transform_thread_block_desc =
@@ -126,11 +123,8 @@ __global__ void gridwise_winograd_convolution(InGlobalDesc,
constexpr auto out_thread_global_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides());
constexpr unsigned out_transform_thread_size = out_transform_thread_desc.GetElementSpace();
constexpr unsigned out_thread_size = out_thread_desc.GetElementSpace();
TFloat p_out_transform_thread[out_transform_thread_size];
TFloat p_out_thread[out_thread_size];
TFloat p_out_transform_thread[out_transform_thread_desc.GetElementSpace()];
TFloat p_out_thread[out_thread_desc.GetElementSpace()];
#if 0
if(blockIdx.x == 0 && threadIdx.x == 0)

View File

@@ -116,10 +116,13 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, TFloat* __restrict__ p, ID
const unsigned did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - shift : desc.GetLength(I0);
const unsigned did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - shift : desc.GetLength(I1);
const unsigned did2_end =
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - shift : desc.GetLength(I2);
const unsigned did3_end =
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - shift : desc.GetLength(I3);