diff --git a/src/include/direct_convolution_2.cuh b/src/include/direct_convolution_2.cuh index 706984f898..9ca2f0e4e4 100644 --- a/src/include/direct_convolution_2.cuh +++ b/src/include/direct_convolution_2.cuh @@ -59,16 +59,13 @@ __global__ void gridwise_convolution(InGlobalDesc, constexpr auto out_block_src_desc = make_ConstantTensorDescriptor( Sequence{}, out_global_desc.GetStrides()); - constexpr auto in_block_dst_desc = - make_ConstantTensorDescriptor(in_block_src_desc.GetLengths()); - constexpr auto wei_block_dst_desc = - make_ConstantTensorDescriptor(wei_block_src_desc.GetLengths()); - constexpr auto out_block_dst_desc = - make_ConstantTensorDescriptor(out_block_src_desc.GetLengths()); + constexpr auto in_block_desc = make_ConstantTensorDescriptor(in_block_src_desc.GetLengths()); + constexpr auto wei_block_desc = make_ConstantTensorDescriptor(wei_block_src_desc.GetLengths()); + constexpr auto out_block_desc = make_ConstantTensorDescriptor(out_block_src_desc.GetLengths()); - constexpr unsigned in_block_size = in_block_dst_desc.GetElementSpace(); - constexpr unsigned wei_block_size = wei_block_dst_desc.GetElementSpace(); - constexpr unsigned out_block_size = out_block_dst_desc.GetElementSpace(); + constexpr unsigned in_block_size = in_block_desc.GetElementSpace(); + constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace(); + constexpr unsigned out_block_size = out_block_desc.GetElementSpace(); __shared__ TFloat p_in_block[in_block_size]; __shared__ TFloat p_wei_block[wei_block_size]; @@ -104,9 +101,9 @@ __global__ void gridwise_convolution(InGlobalDesc, print_ConstantTensorDescriptor( in_block_src_desc, "gridwise_convolution: in_block_src_desc: "); print_ConstantTensorDescriptor(wei_block_src_desc, "gridwise_convolution: wei_block_src_desc: "); print_ConstantTensorDescriptor(out_block_src_desc, "gridwise_convolution: out_block_src_desc: "); - print_ConstantTensorDescriptor( in_block_dst_desc, "gridwise_convolution: in_block_dst_desc: "); - print_ConstantTensorDescriptor(wei_block_dst_desc, "gridwise_convolution: wei_block_dst_desc: "); - print_ConstantTensorDescriptor(out_block_dst_desc, "gridwise_convolution: out_block_dst_desc: "); + print_ConstantTensorDescriptor( in_block_desc, "gridwise_convolution: in_block_desc: "); + print_ConstantTensorDescriptor(wei_block_desc, "gridwise_convolution: wei_block_desc: "); + print_ConstantTensorDescriptor(out_block_desc, "gridwise_convolution: out_block_desc: "); printf("NBlockWork %u, KBlockWork %u, YBlockWork %u, XBlockWork %u \t" "block_id %u, n_block_work_id %u, k_block_work_id %u, y_block_work_id %u, " @@ -129,13 +126,13 @@ __global__ void gridwise_convolution(InGlobalDesc, // set output tensor in LDS to 0 blockwise_4d_tensor_op_unary(out_block_dst_desc, p_out_block, f_set0); + BlockSize>(out_block_desc, p_out_block, f_set0); for(unsigned c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1); c_block_work_begin += CPerBlock) @@ -144,26 +141,26 @@ __global__ void gridwise_convolution(InGlobalDesc, // copy input tensor to LDS blockwise_4d_tensor_op_binary( - in_block_src_desc, - p_in_global + in_block_src_desc.Get1dIndex(n_block_work_begin, - c_block_work_begin, - hi_block_work_begin, - wi_block_work_begin), - in_block_dst_desc, - p_in_block, - f_copy); + BlockSize>(in_block_src_desc, + p_in_global + + in_global_desc.Get1dIndex(n_block_work_begin, + c_block_work_begin, + hi_block_work_begin, + wi_block_work_begin), + in_block_desc, + p_in_block, + f_copy); // copy weight tensor to LDS blockwise_4d_tensor_op_binary( wei_block_src_desc, - p_wei_global + - wei_block_src_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0), - wei_block_dst_desc, + p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0), + wei_block_desc, p_wei_block, f_copy); @@ -183,17 +179,13 @@ __global__ void gridwise_convolution(InGlobalDesc, // blockwise convolution blockwise_convolution(in_block_dst_desc, - p_in_block, - wei_block_dst_desc, - p_wei_block, - out_block_dst_desc, - p_out_block); + BlockSize>( + in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block); #if 1 __syncthreads(); @@ -202,7 +194,7 @@ __global__ void gridwise_convolution(InGlobalDesc, // copy output tensor from LDS to device mem blockwise_4d_tensor_op_binary( - out_block_dst_desc, + out_block_desc, p_out_block, out_block_src_desc, p_out_global + - out_block_src_desc.Get1dIndex( + out_global_desc.Get1dIndex( n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin), f_copy); }