experimenting

[ROCm/composable_kernel commit: 766b0a9eaf]
This commit is contained in:
Chao Liu
2019-03-24 12:09:57 -05:00
parent 6f1651f8a7
commit cd883e7581
33 changed files with 1886 additions and 1822 deletions

View File

@@ -10,7 +10,7 @@ void device_direct_convolution_1(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,28 +34,28 @@ void device_direct_convolution_1(InDesc,
#if 1
// 3x3, 34x34
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 16;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 16;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(gridwise_direct_convolution_1<T,
InDesc,

View File

@@ -10,7 +10,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& wei,
OutDesc,
Tensor<T>& out,
unsigned nrepeat)
index_t nrepeat)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
@@ -34,49 +34,49 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
#if 1
// 3x3, 34x34, 128 thread
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3, 34x34, 128 thread, fp16
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,

View File

@@ -10,13 +10,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
const Tensor<TInWei>& wei_kcyx,
OutDesc,
Tensor<TOut>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
// this suppose in / wei data type is int8x4
constexpr unsigned NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
constexpr index_t NVector = 4;
using accum_t = int32_t;
using vector_t = vector_type<TInWei, NVector>;
using vector_mem_t = typename vector_t::MemoryType;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -27,17 +27,17 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// vectorized input
auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
@@ -96,84 +96,84 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
#if 0
// 3x3, 34x34, 128 thread, fp32, vector = 1
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, fp32, vector = 2
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 2;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 4;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3, 34x34, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// 1x1, 32x32, 128 thread, int8, vector = 4
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 16;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 32;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 16;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 32;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 8;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 4;
constexpr unsigned WoPerThread = 2;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr unsigned InBlockCopyDataPerRead = 2;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t InBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,

View File

@@ -10,7 +10,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -21,17 +21,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -76,218 +76,218 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
#if 0
// for 3x3, 34x34
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 4;
constexpr unsigned InBlockCopy_ThreadPerDimH = 4;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 2;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 2;
constexpr index_t WeiBlockCopyDataPerRead = 2;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 2; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4; // not used, yet
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// for 1x1, 14x14
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned InBlockCopy_ThreadPerDimC = 8;
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
constexpr unsigned InBlockCopy_ThreadPerDimW = 2;
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned OutThreadCopyDataPerWrite = 2;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn<GridSize,

View File

@@ -12,7 +12,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
Tensor<T>& out_nkhw,
LowerPads,
UpperPads,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -23,17 +23,17 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned N = out_nkhw_desc.GetLength(I0);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
@@ -77,177 +77,177 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(InDesc,
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 1;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 1;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 1;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 1;
constexpr index_t WeiBlockCopyThreadPerDim0 = 1;
constexpr index_t WeiBlockCopyThreadPerDim1 = 1;
constexpr unsigned BlockSize = 8;
constexpr index_t BlockSize = 8;
#elif 1
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 3x3 58x58, NKC = 16,256,128
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5, 36x36
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 7x7, 38x38
constexpr unsigned NPerBlock = 8;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 8;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 56x56
constexpr unsigned NPerBlock = 32;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 1
// 3x3 56x56, NKC = 16,256,128, with padding
// 3x3 28x28, NKC = 16,512,256, with padding
// 3x3 20x84, NKC = 16,256,256, with padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 2;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 64;
constexpr index_t WeiBlockCopyThreadPerDim0 = 2;
constexpr index_t WeiBlockCopyThreadPerDim1 = 64;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 5x5 filter, 20x84 image, 1x1 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 1;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 1;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned HoPerBlock = 4;
constexpr unsigned WoPerBlock = 4;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 32;
constexpr index_t CPerBlock = 2;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 4;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 1;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// for 1x1, 28x28
constexpr unsigned NPerBlock = 16;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr unsigned HoPerBlock = 2;
constexpr unsigned WoPerBlock = 2;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;
constexpr unsigned NPerThread = 4;
constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 2;
constexpr unsigned HoPerThread = 1;
constexpr unsigned WoPerThread = 1;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t CPerThread = 2;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 1;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 32;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded<GridSize,

View File

@@ -11,7 +11,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
unsigned nrepeat)
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -22,19 +22,19 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto out_nkhw_desc = OutDesc{};
constexpr unsigned N = in_nchw_desc.GetLength(I0);
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
constexpr index_t N = in_nchw_desc.GetLength(I0);
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr unsigned K = wei_kcyx_desc.GetLength(I0);
constexpr unsigned C = wei_kcyx_desc.GetLength(I1);
constexpr unsigned Y = wei_kcyx_desc.GetLength(I2);
constexpr unsigned X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1);
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// convert in_nchw to in_cnhw
auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
@@ -71,128 +71,158 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
#if 0
// 3x3, 34x34
// need to use register double buffer for GEMM
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 4;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 8;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 64 threads
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 2;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 64;
#elif 1
constexpr index_t BlockSize = 64;
#elif 0
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr unsigned BPerBlock = 64;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif 0
// 1x1, 28x28, 256 thread
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 128;
constexpr unsigned CPerBlock = 8;
constexpr index_t BPerBlock = 128;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr unsigned BPerThread = 8;
constexpr unsigned KPerThread = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr unsigned GemmMPerThreadSubC = 4;
constexpr unsigned GemmNPerThreadSubC = 4;
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 4;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned GemmKPerThreadLoop = 1;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
constexpr unsigned GemmThreadPerRowPerCluster = 8;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr unsigned InBlockCopyDataPerRead = 4;
constexpr unsigned WeiBlockCopyDataPerRead = 4;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr unsigned BlockSize = 256;
constexpr index_t BlockSize = 256;
#elif 1
// 1x1, 14x14, Vega 10
constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadPerColumnPerCluster = 8;
constexpr index_t GemmThreadPerRowPerCluster = 8;
constexpr index_t InBlockCopyThreadPerDim0 = 4;
constexpr index_t InBlockCopyThreadPerDim1 = 16;
constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128;
#endif
constexpr unsigned GridSize =
constexpr index_t GridSize =
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -208,7 +238,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
for(unsigned i = 0; i < nrepeat; ++i)
for(index_t i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
#if 1

View File

@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
template <class... Ts>
double operator()(Ts... Xs) const
{
std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
return std::accumulate(dims.begin(),
dims.end(),
true,
[](bool init, unsigned long x) -> int { return init != (x % 2); })
[](bool init, index_t x) -> int { return init != (x % 2); })
? 1
: -1;
}
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
constexpr auto I3 = Number<3>{};
constexpr auto desc = TConstTensorDesc{};
std::initializer_list<unsigned> lengths = {
std::initializer_list<index_t> lengths = {
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
std::initializer_list<unsigned> strides = {
std::initializer_list<index_t> strides = {
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
return TensorDescriptor(lengths, strides);
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
LowerPads,
UpperPads)
{
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1;
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
#if 0
constexpr unsigned N = 1;
constexpr unsigned C = 1;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 1;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 34x34
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 34;
constexpr unsigned WI = 34;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 64;
constexpr unsigned C = 64;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 64;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 5x5, 36x36
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 36;
constexpr unsigned WI = 36;
constexpr unsigned K = 64;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 36;
constexpr index_t WI = 36;
constexpr index_t K = 64;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 7x7, 38x38
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 38;
constexpr unsigned WI = 38;
constexpr unsigned K = 64;
constexpr unsigned Y = 7;
constexpr unsigned X = 7;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 38;
constexpr index_t WI = 38;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 58;
constexpr index_t WI = 58;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 56x56 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 128;
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 128;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 28x28 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 1x1 filter, 28x28 image
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3 filter, 20x84 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 84;
constexpr unsigned K = 256;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 84;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 3x3 filter, 112x112 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 64;
constexpr unsigned HI = 112;
constexpr unsigned WI = 112;
constexpr unsigned K = 128;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr index_t N = 16;
constexpr index_t C = 64;
constexpr index_t HI = 112;
constexpr index_t WI = 112;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 20x86 image, 1x1 padding
constexpr unsigned N = 16;
constexpr unsigned C = 256;
constexpr unsigned HI = 20;
constexpr unsigned WI = 86;
constexpr unsigned K = 512;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 16;
constexpr index_t C = 256;
constexpr index_t HI = 20;
constexpr index_t WI = 86;
constexpr index_t K = 512;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 0
// 5x5 filter, 28x28 image, 2x2 padding
constexpr unsigned N = 16;
constexpr unsigned C = 192;
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 32;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr index_t N = 16;
constexpr index_t C = 192;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 32;
constexpr index_t Y = 5;
constexpr index_t X = 5;
constexpr unsigned HPad = 2;
constexpr unsigned WPad = 2;
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
#elif 0
// 1x1 filter, 32x32 image
constexpr unsigned N = 64;
constexpr unsigned C = 256;
constexpr unsigned HI = 32;
constexpr unsigned WI = 32;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 32;
constexpr index_t WI = 32;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 1x1 filter, 14x14 image
constexpr unsigned N = 128;
constexpr unsigned C = 2048;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
// 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128;
constexpr index_t C = 2048;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 1x1 filter, 14x14 image, C = 512
constexpr unsigned N = 128;
constexpr unsigned C = 512;
constexpr unsigned HI = 14;
constexpr unsigned WI = 14;
constexpr unsigned K = 512;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#endif
auto lower_pads = Sequence<HPad, WPad>{};
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
}
bool do_verification = atoi(argv[1]);
unsigned nrepeat = atoi(argv[2]);
index_t nrepeat = atoi(argv[2]);
if(do_verification)
{