mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
added implicit gemm v1r3, refactored decomposition of wei tensor (loop over y, x first, and C second) to allow easy lds double buffer on C
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
#include "device.hpp"
|
||||
#include "gridwise_convolution_wrapper.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
@@ -94,9 +96,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
@@ -108,10 +110,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
|
||||
constexpr index_t NPerBlock = 4;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
@@ -128,9 +130,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
@@ -142,7 +144,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
@@ -172,14 +174,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 2;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
|
||||
|
||||
constexpr index_t BlockSize = 256;
|
||||
#elif 0
|
||||
// for 3x3, 56x56, v1, Pascal
|
||||
// for 3x3, 56x56, v1r1, Pascal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
@@ -195,9 +197,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
@@ -207,7 +209,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
@@ -237,13 +239,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 28x28, v1r1, Pacal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
@@ -260,9 +262,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
@@ -274,11 +276,13 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 28x28, v1r2, Pascal
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
@@ -290,13 +294,37 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
#elif 1
|
||||
// for 3x3, 28x28, v1r3, Pascal
|
||||
// for 3x3, 14x14, v1r3, Pascal
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
@@ -308,11 +336,14 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
#elif 0
|
||||
// for 1x1, 28x28
|
||||
// for 1x1, 28x28, v1r1, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
@@ -329,9 +360,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
@@ -341,11 +372,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// for 1x1, 14x14, Pascal
|
||||
#elif 0
|
||||
// for 1x1, 14x14, v1r1, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
@@ -369,10 +400,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#endif
|
||||
@@ -386,12 +417,16 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
constexpr auto gridwise_conv =
|
||||
#if 1
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
@@ -417,13 +452,10 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
Sequence<InBlockCopy_ThreadPerDimC,
|
||||
InBlockCopy_ThreadPerDimH,
|
||||
InBlockCopy_ThreadPerDimW,
|
||||
InBlockCopy_ThreadPerDimN>,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead,
|
||||
OutThreadCopyDataPerWrite>{};
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerRead_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
OutThreadCopyDataPerWrite_N>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
|
||||
@@ -87,13 +87,13 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster = Sequence<1, 2, 3, 0>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 2;
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 4;
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 2;
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 4;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
|
||||
using WeiBlockCopyClusterLengths_CXK = Sequence<4, 1, 32>;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_C = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
@@ -137,10 +137,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
GemmDataPerReadB,
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
InBlockReorderMapThreadCluster2SrcCluster,
|
||||
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyClusterLengths,
|
||||
WeiBlockCopyClusterLengths_CXK,
|
||||
WeiBlockCopyDataPerRead_C,
|
||||
OutThreadCopyDataPerWrite_N>{};
|
||||
|
||||
|
||||
@@ -451,60 +451,6 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 58x58
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 58;
|
||||
constexpr index_t WI = 58;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
#elif 0
|
||||
// 3x3, 58x58
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 58;
|
||||
constexpr index_t WI = 58;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
#elif 0
|
||||
// 3x3 filter, 58x58 image, 0x0 padding
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 58;
|
||||
constexpr index_t WI = 58;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3 filter, 56x56 image, 1x1 padding
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 1;
|
||||
constexpr index_t WPad = 1;
|
||||
#elif 0
|
||||
// 3x3 filter, 28x28 image, 1x1 padding
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 1;
|
||||
constexpr index_t WPad = 1;
|
||||
#elif 1
|
||||
// 3x3 filter, 28x28 image
|
||||
constexpr index_t N = 128;
|
||||
@@ -578,31 +524,19 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t HPad = 2;
|
||||
constexpr index_t WPad = 2;
|
||||
#elif 0
|
||||
// 1x1 filter, 32x32 image
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 32;
|
||||
constexpr index_t WI = 32;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
// 1x1 filter, 14x14 image, C = 2048
|
||||
// 3x3 filter, 14x14 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 2048;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
// 1x1 filter, 14x14 image, C = 512
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 14;
|
||||
@@ -673,9 +607,9 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_2_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
|
||||
|
||||
Reference in New Issue
Block a user