adding implicit gemm v3

This commit is contained in:
Chao Liu
2019-05-22 19:39:56 -05:00
parent 2a48812edb
commit 8a4b59785b
26 changed files with 373 additions and 259 deletions

View File

@@ -38,7 +38,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_packed_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
auto wei_cyxk_desc = make_ConstantTensorDescriptor_default_rank_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
@@ -51,7 +51,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
std::thread::hardware_concurrency());
// reorder input
auto in_chwn_desc = make_packed_ConstantTensorDescriptor(Sequence<C, Hi, Wi, N>{});
auto in_chwn_desc = make_ConstantTensorDescriptor_default_rank_packed(Sequence<C, Hi, Wi, N>{});
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
@@ -64,7 +64,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_packed_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
auto out_khwn_desc =
make_ConstantTensorDescriptor_default_rank_packed(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

View File

@@ -37,7 +37,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_packed_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
auto wei_cyxk_desc = make_ConstantTensorDescriptor_default_rank_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
@@ -50,7 +50,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
std::thread::hardware_concurrency());
// output
auto out_khwn_desc = make_packed_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
auto out_khwn_desc =
make_ConstantTensorDescriptor_default_rank_packed(Sequence<K, Ho, Wo, N>{});
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));

View File

@@ -36,7 +36,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_packed_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
auto wei_cyxk_desc = make_ConstantTensorDescriptor_default_rank_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
@@ -57,7 +57,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 0
#if 1
// for 3x3, 34x34, v1r3, Pascal
constexpr index_t BlockSize = 128;
@@ -92,7 +92,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 0
#elif 1
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
constexpr index_t BlockSize = 256;
@@ -162,7 +162,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
#elif 1
#elif 0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
constexpr index_t BlockSize = 256;

View File

@@ -35,7 +35,7 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
// reorder weight
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
auto wei_cyxk_desc = make_ConstantTensorDescriptor_default_rank_packed(Sequence<C, Y, X, K>{});
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
@@ -56,37 +56,40 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1
// for 3x3, 28x28, v3, Pascal
constexpr index_t BlockSize = 128;
// for 3x3, 28x28, v3
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
using InBlockCopySubLengths_N1_N2_C_B = Sequence<1, 4, 1, 1>;
using InBlockCopyClusterLengths_N1_N2_C_B = Sequence<2, 1, 8, 16>;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
#endif
constexpr index_t GridSize =
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
@@ -102,15 +105,11 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
decltype(in_nchw_desc),
decltype(wei_cyxk_desc),
decltype(out_nkhw_desc),
NPerBlock,
BPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
HoPerThread,
WoPerThread,
N1,
N2,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
@@ -120,14 +119,11 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockReorderSrcSubLengths_NCHW,
InBlockReorderSrcClusterLengths_NCHW,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_W>{};
InBlockCopySubLengths_N1_N2_C_B,
InBlockCopyClusterLengths_N1_N2_C_B,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDataPerAccess_K>{};
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize),

View File

@@ -13,7 +13,7 @@
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
struct GeneratorTensor_1
{
@@ -548,8 +548,8 @@ int main(int argc, char* argv[])
auto lower_pads = Sequence<HPad, WPad>{};
auto upper_pads = Sequence<HPad, WPad>{};
auto in_nchw_desc = make_packed_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_packed_ConstantTensorDescriptor(Sequence<K, C, Y, X>{});
auto in_nchw_desc = make_ConstantTensorDescriptor_default_rank_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_default_rank_packed(Sequence<K, C, Y, X>{});
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads);
@@ -612,11 +612,11 @@ int main(int argc, char* argv[])
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 1
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#elif 0
#elif 1
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);