From 6a3f3f951d8a4b8acfe8344e87fe7802e758931a Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 3 Apr 2019 17:08:14 -0500 Subject: [PATCH] add --- ...icit_gemm_convolution_2_chwn_cyxk_khwn.hpp | 6 +++--- driver/driver.hip.cpp | 2 +- src/include/blockwise_gemm.hip.hpp | 20 +++++++++++-------- ..._gemm_convolution_2_chwn_cyxk_khwn.hip.hpp | 3 +-- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index bf7cdc8c5a..3690622c86 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -224,7 +224,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, // 1x1, 14x14, Vega 20, hack CPerBlock = 1 constexpr index_t BPerBlock = 64; constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 1; + constexpr index_t CPerBlock = 8; constexpr index_t BPerThread = 8; constexpr index_t KPerThread = 8; @@ -232,7 +232,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmKPerThreadLoop = 1; @@ -249,7 +249,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, constexpr index_t InBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t BlockSize = 128; + constexpr index_t BlockSize = 256; #endif constexpr index_t GridSize = diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index a83e4082c7..0ea091e607 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -580,7 +580,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 1x1 filter, 14x14 image, C = 2048 constexpr index_t N = 128; constexpr index_t C = 2048; diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 7b1ed63702..dcbb4e7e92 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -404,10 +404,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); #else int k = k_begin; - ds_read_b128(reg_a[0], a_loc, k * 512); - ds_read_b128(reg_b[0], b_loc, k * 256); - ds_read_b128(reg_b[1], b_loc, 128 + k * 256); - ds_read_b128(reg_a[1], a_loc, 256 + k * 512); + int lds_a_block_off = sizeof(Float) * M; + int lds_b_block_off = sizeof(Float) * N; + int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float); + int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float); + ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off); + ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off); + ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off); + ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off); lgkmcnt(2); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); lgkmcnt(1); @@ -416,12 +420,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 for(int i = 0; i < k_chunk - 1; i++) { k = k + 1; - ds_read_b128(reg_a[0], a_loc, k * 512); + ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); - ds_read_b128(reg_b[0], b_loc, k * 256); + ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); - ds_read_b128(reg_b[1], b_loc, 128 + k * 256); - ds_read_b128(reg_a[1], a_loc, 256 + k * 512); + ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off); + ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off); lgkmcnt(2); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); lgkmcnt(1); diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index 657c233e5e..573b5923b7 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -297,8 +297,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), - __syncthreads()) + p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), __syncthreads()) { // load data blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);