diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp index 93e53d304c..cca1fa29c9 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp @@ -77,7 +77,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); out_khwn_device_buf.ToDevice(out_khwn.mData.data()); -#if 1 +#if 0 // for 3x3, 34x34 constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 64; @@ -230,7 +230,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t BlockSize = 128; -#elif 1 +#elif 0 // for 3x3, 56x56, v1r2, Pascal // for 3x3, 34x34, v1r2, Pascal constexpr index_t NPerBlock = 16; @@ -263,6 +263,40 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t OutThreadCopyDataPerWrite = 4; + constexpr index_t BlockSize = 128; +#elif 1 + // for 3x3, 28x28, v1, Pacal + constexpr index_t NPerBlock = 32; + constexpr index_t KPerBlock = 64; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 2; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t InBlockCopy_ThreadPerDimC = 1; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t OutThreadCopyDataPerWrite = 2; + constexpr index_t BlockSize = 128; #elif 0 // for 1x1, 28x28 diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index f9c8a3ee21..d8a5ccf13e 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -511,6 +511,18 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 1; constexpr index_t WPad = 1; +#elif 1 + // 3x3 filter, 28x28 image + constexpr index_t N = 128; + constexpr index_t C = 256; + constexpr index_t HI = 28; + constexpr index_t WI = 28; + constexpr index_t K = 512; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 1x1 filter, 28x28 image constexpr index_t N = 16; @@ -667,9 +679,9 @@ int main(int argc, char* argv[]) device_direct_convolution_2_nchw_kcyx_nkhw #elif 0 device_direct_convolution_2_vectorized_nchw_kcyx_nkhw -#elif 0 - device_implicit_gemm_convolution_1_chwn_cyxk_khwn #elif 1 + device_implicit_gemm_convolution_1_chwn_cyxk_khwn +#elif 0 device_implicit_gemm_convolution_2_chwn_cyxk_khwn #endif (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);