This commit is contained in:
Chao Liu
2019-03-16 10:50:46 -05:00
parent ce0182ce05
commit fd8de38417
6 changed files with 359 additions and 71 deletions

View File

@@ -8,7 +8,7 @@
#include "ConstantTensorDescriptor.hip.hpp"
#include "conv_common.hip.hpp"
#include "device_direct_convolution_1.hpp"
#include "device_direct_convolution_2.hpp"
#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp"
#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
#include "device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp"
@@ -503,7 +503,7 @@ int main(int argc, char* argv[])
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
#elif 0
#elif 1
// 1x1 filter, 28x28 image
constexpr unsigned N = 16;
constexpr unsigned C = 256;
@@ -577,10 +577,11 @@ int main(int argc, char* argv[])
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
Tensor<half> in_nchw(make_TensorDescriptor(in_nchw_desc));
Tensor<half> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
Tensor<half> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
Tensor<half> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
using Float = float;
Tensor<Float> in_nchw(make_TensorDescriptor(in_nchw_desc));
Tensor<Float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
Tensor<Float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
Tensor<Float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency();
@@ -610,9 +611,9 @@ int main(int argc, char* argv[])
#if 1
#if 0
device_direct_convolution_1
#elif 0
device_direct_convolution_2
#elif 1
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 0
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
@@ -633,7 +634,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 0
#if 1
if(Y == 3 && X == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);