cleaning up

This commit is contained in:
Chao Liu
2019-08-09 22:48:28 -05:00
parent 4908fe3fdc
commit 1c4ef23cff
3 changed files with 22 additions and 21 deletions

View File

@@ -54,11 +54,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2;
constexpr index_t N2 = 4;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1
// each thread holds 64 data elements
constexpr index_t BlockSize = 256;
@@ -67,6 +62,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 8;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
@@ -168,6 +165,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#endif
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
constexpr index_t GridSize =
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
@@ -192,8 +194,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
BPerBlock,
KPerBlock,
EPerBlock,
N1,
N2,
GemmNRepeat,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
@@ -216,8 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_W>{};
WeiBlockCopyDstDataPerWrite_K>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),