added implicit gemm v1r3, refactored decomposition of wei tensor (loop over y, x first, and C second) to allow easy lds double buffer on C

This commit is contained in:
Chao Liu
2019-04-19 16:46:29 -05:00
parent 5ce19234a4
commit 6d066ede00
8 changed files with 772 additions and 139 deletions

View File

@@ -3,8 +3,10 @@
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
@@ -94,9 +96,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -108,10 +110,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 1
#elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
@@ -128,9 +130,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -142,7 +144,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
@@ -172,14 +174,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead_N = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
constexpr index_t BlockSize = 256;
#elif 0
// for 3x3, 56x56, v1, Pascal
// for 3x3, 56x56, v1r1, Pascal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
@@ -195,9 +197,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -207,7 +209,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 0
@@ -237,13 +239,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
constexpr index_t BlockSize = 128;
#elif 1
#elif 0
// for 3x3, 28x28, v1r1, Pacal
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
@@ -260,9 +262,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -274,11 +276,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 1
#elif 0
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
@@ -290,13 +294,37 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
constexpr index_t BlockSize = 128;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -308,11 +336,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t BlockSize = 128;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 1x1, 28x28
// for 1x1, 28x28, v1r1, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
@@ -329,9 +360,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
@@ -341,11 +372,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#elif 1
// for 1x1, 14x14, Pascal
#elif 0
// for 1x1, 14x14, v1r1, Pascal
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
@@ -369,10 +400,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
constexpr index_t BlockSize = 128;
#endif
@@ -386,12 +417,16 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 1
#if 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
#elif 1
GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
#endif
<GridSize,
BlockSize,
@@ -417,13 +452,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
Sequence<InBlockCopy_ThreadPerDimC,
InBlockCopy_ThreadPerDimH,
InBlockCopy_ThreadPerDimW,
InBlockCopy_ThreadPerDimN>,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead,
OutThreadCopyDataPerWrite>{};
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_N>{};
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize),

View File

@@ -87,13 +87,13 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster = Sequence<1, 2, 3, 0>;
constexpr index_t InBlockReorderDataPerRead_W = 2;
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 2;
constexpr index_t InBlockReorderDataPerWrite_N = 4;
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
using WeiBlockCopyClusterLengths_CXK = Sequence<4, 1, 32>;
constexpr index_t WeiBlockCopyDataPerRead_C = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
@@ -137,10 +137,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
GemmDataPerReadB,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
InBlockReorderMapThreadCluster2SrcCluster,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths,
WeiBlockCopyClusterLengths_CXK,
WeiBlockCopyDataPerRead_C,
OutThreadCopyDataPerWrite_N>{};

View File

@@ -451,60 +451,6 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3, 58x58
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 1
// 3x3 filter, 28x28 image
constexpr index_t N = 128;
@@ -578,31 +524,19 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
#elif 0
// 1x1 filter, 32x32 image
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 32;
constexpr index_t WI = 32;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 2048
// 3x3 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 512
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
@@ -673,9 +607,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn