This commit is contained in:
Chao Liu
2019-04-13 15:53:03 -05:00
parent 00899f191b
commit 7d8daba741
4 changed files with 45 additions and 123 deletions

View File

@@ -111,8 +111,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 3x3, 34x34, v1r2, Pascal
#elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
@@ -146,7 +146,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 34x34, Vega 20
// for 3x3, 34x34, v1r1, Vega 20
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 4;
@@ -179,94 +179,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 256;
#elif 0
// for 5x5, 36x36
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56, v1, Pacal
// for 3x3, 56x56, v1, Pascal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
@@ -299,7 +212,6 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56, v1r2, Pascal
// for 3x3, 34x34, v1r2, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
@@ -321,8 +233,8 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 1;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
@@ -332,7 +244,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 28x28, v1, Pacal
// for 3x3, 28x28, v1r1, Pacal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
@@ -364,6 +276,40 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28

View File

@@ -421,7 +421,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 1
#if 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
@@ -454,30 +454,6 @@ int main(int argc, char* argv[])
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 5x5, 36x36
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 36;
constexpr index_t WI = 36;
constexpr index_t K = 64;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 7x7, 38x38
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 38;
constexpr index_t WI = 38;
constexpr index_t K = 128;
constexpr index_t Y = 7;
constexpr index_t X = 7;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr index_t N = 16;
@@ -523,7 +499,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
#elif 1
// 3x3 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 256;