update implicit GEMM forward v4r4 to use gridwise gemm (#9)

* updated fwd v4r4 to use gridwise gemm
* updated gridwise gemm api calls in bwd-data v1r1 and v2r1
This commit is contained in:
Chao Liu
2019-12-05 12:36:36 -06:00
committed by GitHub
parent 19a93dac05
commit e2b4c5b469
12 changed files with 599 additions and 842 deletions

View File

@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 0
#if 1
constexpr index_t N = 8;
constexpr index_t C = 128;
constexpr index_t HI = 16;

View File

@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
#elif 0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
ConvStrides{},
ConvDilations{},
nrepeat);
#elif 1
#elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,