From c5e5a9307bf12034a529fa50558da16d844374ed Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 30 Jul 2019 12:10:28 -0500 Subject: [PATCH] retune implicit gemm v4r1 --- ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 36 ++++++++++++++++++- driver/src/driver.cpp | 8 ++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index e9046bd13a..6b7e1c4451 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -59,7 +59,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t B = (N * Ho * Wo) / (N1 * N2); -#if 1 +#if 0 constexpr index_t BlockSize = 256; constexpr index_t BPerBlock = 16; @@ -93,6 +93,40 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 1 + constexpr index_t BlockSize = 256; + + constexpr index_t BPerBlock = 16; + constexpr index_t KPerBlock = 64; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; #elif 0 constexpr index_t BlockSize = 256; diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 56d20bc20a..02abdae973 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -72,11 +72,11 @@ int main(int argc, char* argv[]) using namespace ck; #if 0 - constexpr index_t N = 2; - constexpr index_t C = 16; + constexpr index_t N = 256; + constexpr index_t C = 1536; constexpr index_t HI = 8; constexpr index_t WI = 8; - constexpr index_t K = 128; + constexpr index_t K = 512; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -532,7 +532,7 @@ int main(int argc, char* argv[]) #elif 0 device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); -#elif 0 +#elif 1 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, in_nchw, wei_kcyx_desc,