This commit is contained in:
Chao Liu
2019-01-08 16:56:46 -06:00
parent 0b8e67ef08
commit df228b3cf5
5 changed files with 57 additions and 62 deletions

View File

@@ -69,8 +69,8 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, CPerThread, S, R>{}, wei_block_desc.GetStrides());
-constexpr auto out_thread_desc =
-get_convolution_output_4d_tensor_descriptor(in_thread_block_desc, wei_thread_block_desc);
+constexpr auto out_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
+in_thread_block_desc, wei_thread_block_desc);
// register
Float p_out_thread[out_thread_desc.GetElementSpace()];