From 1c4ef23cff46f627ea22c8e2afc68218017f2523 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 9 Aug 2019 22:48:28 -0500 Subject: [PATCH] cleaning up --- ..._v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp | 23 ++++++++++--------- ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 18 +++++++-------- driver/src/driver.cpp | 2 +- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index 491e9a0914..8c172111f3 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -1,5 +1,5 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER +#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP +#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP #include "common_header.hpp" #include "ConstantTensorDescriptor.hpp" @@ -23,8 +23,7 @@ template + index_t WeiBlockCopyDstDataPerWrite_K> struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer { - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const + __device__ void __launch_bounds__(BlockSize, 2) + Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const { // this is a mess // TODO: find more elegent way of specifying (or calculating) performance parameters - static_assert(N2 == GemmNPerThreadSubC, "wrong!"); + constexpr index_t N1 = GemmNRepeat; + constexpr index_t N2 = GemmNPerThreadSubC; + static_assert((N1 * N2 * BPerBlock) % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, @@ -464,4 +465,4 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer }; } // namespace ck -#endif +#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 1aa4590488..3b37d08132 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -54,11 +54,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - constexpr index_t N1 = 2; - constexpr index_t N2 = 4; - - constexpr index_t B = (N * Ho * Wo) / (N1 * N2); - #if 1 // each thread hold 64 data constexpr index_t BlockSize = 256; @@ -67,6 +62,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t KPerBlock = 128; constexpr index_t EPerBlock = 8; + constexpr index_t GemmNRepeat = 2; + constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; @@ -168,6 +165,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; #endif + constexpr index_t N1 = GemmNRepeat; + constexpr index_t N2 = GemmNPerThreadSubC; + + constexpr index_t B = (N * Ho * Wo) / (N1 * N2); + constexpr index_t GridSize = ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); @@ -192,8 +194,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, BPerBlock, KPerBlock, EPerBlock, - N1, - N2, + GemmNRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC, GemmMLevel0Cluster, @@ -216,8 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder, WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_K, - OutThreadCopyDataPerAccess_W>{}; + WeiBlockCopyDstDataPerWrite_K>{}; float time = launch_kernel(run_gridwise_convolution_kernel, dim3(GridSize), diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 7110a1a45e..64892c74b2 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -379,7 +379,7 @@ int main(int argc, char* argv[]) #elif 0 device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); -#elif 0 +#elif 1 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, in_nchw, wei_kcyx_desc,