From 8edbc659b88ed2147984dd0f02096056ec6b89e7 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 6 Mar 2019 12:34:31 -0600 Subject: [PATCH] refactor --- driver/driver.hip.cpp | 10 ++--- src/include/threadwise_nd_tensor_op.hip.hpp | 46 +++------------------ 2 files changed, 11 insertions(+), 45 deletions(-) diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 3b18645c4b..b5af7009ce 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -577,10 +577,10 @@ int main(int argc, char* argv[]) ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: "); ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); - Tensor in_nchw(make_TensorDescriptor(in_nchw_desc)); - Tensor wei_kcsr(make_TensorDescriptor(wei_kcsr_desc)); - Tensor out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); - Tensor out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); + Tensor in_nchw(make_TensorDescriptor(in_nchw_desc)); + Tensor wei_kcsr(make_TensorDescriptor(wei_kcsr_desc)); + Tensor out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); + Tensor out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); std::size_t num_thread = std::thread::hardware_concurrency(); @@ -633,7 +633,7 @@ int main(int argc, char* argv[]) if(do_verification) { -#if 1 +#if 0 if(Y == 3 && X == 3) { host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads); diff --git a/src/include/threadwise_nd_tensor_op.hip.hpp b/src/include/threadwise_nd_tensor_op.hip.hpp index 97206e88f5..c787afae77 100644 --- a/src/include/threadwise_nd_tensor_op.hip.hpp +++ b/src/include/threadwise_nd_tensor_op.hip.hpp @@ -10,8 +10,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, SrcOpLengths, Number) { - using Float2 = float2; - using Float4 = float4; + using vector_t = typename vector_type::type; static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 && SrcOpLengths::nDim == 6, @@ -62,24 +61,8 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, const unsigned dst_index = dst_desc.Get1dIndex( did0, did1, did2, did3, did4, iloop_d5 * DataPerRead); - if(DataPerRead == 1) - { - p_dst[dst_index] = p_src[src_index]; - } - else if(DataPerRead == 2) - { - *(reinterpret_cast(p_dst + dst_index)) = - *(reinterpret_cast(p_src + src_index)); - } - else if(DataPerRead == 4) - { - *(reinterpret_cast(p_dst + dst_index)) = - *(reinterpret_cast(p_src + src_index)); - } - else - { - assert(false); - } + *(reinterpret_cast(p_dst + dst_index)) = + *(reinterpret_cast(p_src + src_index)); } } } @@ -97,8 +80,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, SrcOpLengths, Number) { - using Float2 = float2; - using Float4 = float4; + using vector_t = typename vector_type::type; static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 && SrcOpLengths::nDim == 8, @@ -169,24 +151,8 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, did6, iloop_d7 * DataPerRead); - if(DataPerRead == 1) - { - p_dst[dst_index] = p_src[src_index]; - } - else if(DataPerRead == 2) - { - *(reinterpret_cast(p_dst + dst_index)) = - *(reinterpret_cast(p_src + src_index)); - } - else if(DataPerRead == 4) - { - *(reinterpret_cast(p_dst + dst_index)) = - *(reinterpret_cast(p_src + src_index)); - } - else - { - assert(false); - } + *(reinterpret_cast(p_dst + dst_index)) = + *(reinterpret_cast(p_src + src_index)); } } }