adding implicit gemm v4r4

This commit is contained in:
Chao Liu
2019-07-28 19:39:57 -05:00
parent 8669e242ad
commit 9ba3b49158
11 changed files with 1005 additions and 27 deletions

View File

@@ -16,6 +16,7 @@
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
struct GeneratorTensor_1
{
@@ -71,13 +72,16 @@ int main(int argc, char* argv[])
using namespace ck;
#if 0
constexpr index_t N = 8;
constexpr index_t N = 2;
constexpr index_t C = 16;
constexpr index_t HI = 3;
constexpr index_t WI = 18;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
@@ -249,7 +253,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
#elif 1
// 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr index_t N = 64;
@@ -265,7 +269,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 1x1 filter, 8x8 image
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
constexpr index_t N = 128;
@@ -491,7 +495,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 1
#if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
@@ -548,7 +552,7 @@ int main(int argc, char* argv[])
ConvStrides{},
ConvDilations{},
nrepeat);
#elif 1
#elif 0
device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
@@ -558,6 +562,16 @@ int main(int argc, char* argv[])
ConvStrides{},
ConvDilations{},
nrepeat);
#elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
nrepeat);
#elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
in_nchw,