device_implicit_gemm_convolution_1_chwn_csrk_khwn: use tensor copy (instead of pointwise) for writing output, 3x3 increased from 78% to 84%, 5x5 from 80% to 84%

This commit is contained in:
Chao Liu
2019-02-19 11:47:46 -06:00
parent 50b96745c6
commit a65ef90308
7 changed files with 795 additions and 60 deletions

View File

@@ -571,16 +571,21 @@ int main()
std::size_t num_thread = std::thread::hardware_concurrency();
bool do_verification = true;
if(do_verification)
{
#if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#endif
}
unsigned nrepeat = 200;
@@ -614,22 +619,23 @@ int main()
nrepeat);
#endif
#if 1
if(S == 3 && R == 3)
if(do_verification)
{
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
}
else
{
host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
}
check_error(out_nkhw_host, out_nkhw_device);
#endif
if(S == 3 && R == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
}
else
{
host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
}
check_error(out_nkhw_host, out_nkhw_device);
#if 0
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif
}
}