diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp index 1332ff739f..3a68099e2e 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp @@ -84,10 +84,10 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t HoPerBlock = 2; constexpr index_t WoPerBlock = 4; - constexpr index_t NPerThread = 8; + constexpr index_t NPerThread = 4; constexpr index_t KPerThread = 8; constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 1; + constexpr index_t WoPerThread = 2; constexpr index_t InBlockCopy_ThreadPerDimC = 4; constexpr index_t InBlockCopy_ThreadPerDimH = 4; @@ -200,7 +200,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t BlockSize = 128; -#elif 0 +#elif 1 // for 3x3, 56x56 constexpr index_t NPerBlock = 32; constexpr index_t KPerBlock = 64; @@ -209,10 +209,26 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t WoPerBlock = 2; constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 16; - constexpr index_t CPerThread = 1; + constexpr index_t KPerThread = 8; constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + + constexpr index_t InBlockCopy_ThreadPerDimC = 1; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 4; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 4; + + constexpr index_t WeiBlockCopyDataPerRead = 4; + constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t BlockSize = 128; #elif 0 @@ -248,7 +264,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t BlockSize = 128; -#elif 1 +#elif 0 // for 1x1, 14x14, Pascal constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 128; @@ -290,7 +306,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, for(index_t i = 0; i < nrepeat; ++i) { constexpr auto gridwise_conv = -#if 0 +#if 1 GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn #else GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index 6777c86370..e37b0a8d3c 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -421,7 +421,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 3x3, 56x56 constexpr index_t N = 64; constexpr index_t C = 64; @@ -430,6 +430,9 @@ int main(int argc, char* argv[]) constexpr index_t K = 64; constexpr index_t Y = 3; constexpr index_t X = 3; + + constexpr index_t HPad = 0; + constexpr index_t WPad = 0; #elif 0 // 3x3, 58x58 constexpr index_t N = 64;