load smaller weight tensor

This commit is contained in:
Chao Liu
2019-04-08 23:57:13 -05:00
parent 5b36aeadeb
commit 796f72e26e
5 changed files with 403 additions and 52 deletions

View File

@@ -4,6 +4,7 @@
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
@@ -178,33 +179,12 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#elif 1
// for 3x3, 56x56
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
// for 7x7, 38x38
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
@@ -216,19 +196,50 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 1
// for 3x3, 56x56
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
@@ -264,7 +275,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
#elif 1
// for 1x1, 14x14, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
@@ -306,9 +317,11 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 1
#if 0
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
#else
#elif 1
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
#endif
<GridSize,

View File

@@ -421,13 +421,13 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 3x3, 56x56
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
@@ -454,13 +454,13 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
#elif 1
// 7x7, 38x38
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 38;
constexpr index_t WI = 38;
constexpr index_t K = 64;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 7;
@@ -644,6 +644,9 @@ int main(int argc, char* argv[])
#if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
@@ -666,7 +669,7 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1
#elif 0
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);