From 96ee9571e2c96ba6eb6972da1be75453d6c6e9fa Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 10 Apr 2019 18:10:18 -0500 Subject: [PATCH] tuned implicit gemm v1 for 3x3 on AMD to 82%. Fixed a bug in 4d tensor blockwise copy. --- ...icit_gemm_convolution_1_chwn_cyxk_khwn.hpp | 37 +++++++- driver/driver.hip.cpp | 8 +- src/include/blockwise_4d_tensor_op.hip.hpp | 24 ++--- src/include/blockwise_batched_gemm.hip.hpp | 88 +++++++++++++++++++ src/include/blockwise_gemm.hip.hpp | 1 + ...on_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp | 31 +++++-- 6 files changed, 163 insertions(+), 26 deletions(-) diff --git a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp index 7a107ef0e1..decf294ab4 100644 --- a/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp @@ -78,7 +78,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, out_khwn_device_buf.ToDevice(out_khwn.mData.data()); #if 0 - // for 3x3, 34x34 + // for 3x3, 34x34, Pascal constexpr index_t NPerBlock = 16; constexpr index_t KPerBlock = 64; constexpr index_t CPerBlock = 4; @@ -111,6 +111,39 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t OutThreadCopyDataPerWrite = 2; constexpr index_t BlockSize = 128; +#elif 1 + // for 3x3, 34x34, Vega 20 + constexpr index_t NPerBlock = 16; + constexpr index_t KPerBlock = 128; + constexpr index_t CPerBlock = 4; + constexpr index_t HoPerBlock = 2; + constexpr index_t WoPerBlock = 4; + + constexpr index_t NPerThread = 4; + constexpr index_t KPerThread = 8; + constexpr index_t HoPerThread = 1; + constexpr index_t WoPerThread = 2; + + constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 2; + constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t InBlockCopy_ThreadPerDimC = 4; + constexpr index_t InBlockCopy_ThreadPerDimH = 4; + constexpr index_t InBlockCopy_ThreadPerDimW = 2; + constexpr index_t InBlockCopy_ThreadPerDimN = 8; + constexpr index_t InBlockCopyDataPerRead = 2; + + constexpr index_t WeiBlockCopyDataPerRead = 2; + constexpr index_t OutThreadCopyDataPerWrite = 4; + + constexpr index_t BlockSize = 256; #elif 0 // for 5x5, 36x36 constexpr index_t NPerBlock = 16; @@ -264,7 +297,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, constexpr index_t OutThreadCopyDataPerWrite = 4; constexpr index_t BlockSize = 128; -#elif 1 +#elif 0 // for 3x3, 28x28, v1, Pacal constexpr index_t NPerBlock = 32; constexpr index_t KPerBlock = 64; diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp index d8a5ccf13e..54b8cb1982 100644 --- a/driver/driver.hip.cpp +++ b/driver/driver.hip.cpp @@ -409,13 +409,13 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 0; constexpr index_t WPad = 0; -#elif 0 +#elif 1 // 3x3, 34x34 constexpr index_t N = 64; constexpr index_t C = 256; constexpr index_t HI = 34; constexpr index_t WI = 34; - constexpr index_t K = 64; + constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -511,7 +511,7 @@ int main(int argc, char* argv[]) constexpr index_t HPad = 1; constexpr index_t WPad = 1; -#elif 1 +#elif 0 // 3x3 filter, 28x28 image constexpr index_t N = 128; constexpr index_t C = 256; @@ -681,7 +681,7 @@ int main(int argc, char* argv[]) device_direct_convolution_2_vectorized_nchw_kcyx_nkhw #elif 1 device_implicit_gemm_convolution_1_chwn_cyxk_khwn -#elif 0 +#elif 1 device_implicit_gemm_convolution_2_chwn_cyxk_khwn #endif (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); diff --git a/src/include/blockwise_4d_tensor_op.hip.hpp b/src/include/blockwise_4d_tensor_op.hip.hpp index 8e26c9b7ca..444190e2a7 100644 --- a/src/include/blockwise_4d_tensor_op.hip.hpp +++ b/src/include/blockwise_4d_tensor_op.hip.hpp @@ -646,6 +646,9 @@ struct Blockwise4dTensorCopy3 constexpr index_t nloop_d2 = L2 / thread_per_d2; constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); + constexpr auto clipboard_desc = make_ConstantTensorDescriptor( + Sequence{}); + #pragma unroll for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) { @@ -664,13 +667,10 @@ struct Blockwise4dTensorCopy3 iloop_d2 * thread_per_d2, iloop_d3 * thread_per_d3 * DataPerRead); - const index_t dst_offset = - DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, - iloop_d1 * thread_per_d1, - iloop_d2 * thread_per_d2, - iloop_d3 * thread_per_d3 * DataPerRead); + const index_t clipboard_offset = clipboard_desc.Get1dIndex( + iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead); - *(reinterpret_cast(&p_clipboard[dst_offset])) = + *(reinterpret_cast(&p_clipboard[clipboard_offset])) = *(reinterpret_cast( &p_src[src_offset + mSrcMyThreadOffset])); } @@ -713,6 +713,9 @@ struct Blockwise4dTensorCopy3 constexpr index_t nloop_d2 = L2 / thread_per_d2; constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead); + constexpr auto clipboard_desc = make_ConstantTensorDescriptor( + Sequence{}); + #pragma unroll for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0) { @@ -725,11 +728,8 @@ struct Blockwise4dTensorCopy3 #pragma unroll for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3) { - const index_t src_offset = - SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, - iloop_d1 * thread_per_d1, - iloop_d2 * thread_per_d2, - iloop_d3 * thread_per_d3 * DataPerRead); + const index_t clipboard_offset = clipboard_desc.Get1dIndex( + iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead); const index_t dst_offset = DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0, @@ -738,7 +738,7 @@ struct Blockwise4dTensorCopy3 iloop_d3 * thread_per_d3 * DataPerRead); *(reinterpret_cast(&p_dst[dst_offset + mDstMyThreadOffset])) = - *(reinterpret_cast(&p_clipboard[src_offset])); + *(reinterpret_cast(&p_clipboard[clipboard_offset])); } } } diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp index 87f17532d4..6919325f98 100644 --- a/src/include/blockwise_batched_gemm.hip.hpp +++ b/src/include/blockwise_batched_gemm.hip.hpp @@ -263,6 +263,94 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 } } +#if DEVICE_BACKEND_HIP + template + __device__ void Run_asm(const FloatA* __restrict__ p_a_block, + const FloatB* __restrict__ p_b_block, + FloatC* __restrict__ p_c_thread) const + { + constexpr auto True = integral_constant{}; + constexpr auto False = integral_constant{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + constexpr index_t M = a_block_mtx.NCol(); + constexpr index_t N = b_block_mtx.NCol(); + constexpr index_t K = a_block_mtx.NRow(); // A is transposed + + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); + + // thread A, B for GEMM + // A is transposed, b is not + constexpr auto a_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + constexpr auto b_thread_mtx = + make_ConstantMatrixDescriptor(Number{}, Number{}); + + // thread A-sub, B-sub for copy + constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( + Number{}, Number{}, Number{}); + + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; + constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + + // assertion for inline asm + static_assert(is_same::value && is_same::value && + is_same::value, + "Run_asm only deal with float\n"); + + static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 && + MPerThread == 8 && NPerThread == 8, + "Run_asm cannot deal with this GEMM shape yet\n"); + + static_assert( + BlockMatrixStrideA == 0 && BatchPerThread == 1, + "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n"); + + using Float4 = vector_type::MemoryType; + + Float4* reg_a = (Float4*)(p_a_thread); + Float4* reg_b = (Float4*)(p_b_thread); + Float4* reg_c = (Float4*)(p_c_thread); + + reg_a[0] = *reinterpret_cast(&p_a_block[mMyThreadOffsetA]); + reg_b[0] = *reinterpret_cast(&p_b_block[mMyThreadOffsetB]); + reg_b[1] = + *reinterpret_cast(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]); + reg_a[1] = + *reinterpret_cast(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]); + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + +#pragma unroll + for(index_t k = 1; k < K; ++k) + { + reg_a[0] = *reinterpret_cast(&p_a_block[mMyThreadOffsetA + k * M]); + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + reg_b[0] = *reinterpret_cast(&p_b_block[mMyThreadOffsetB + k * N]); + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); + reg_b[1] = *reinterpret_cast( + &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]); + reg_a[1] = *reinterpret_cast( + &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]); + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + } + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); + } +#endif + template __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread, FloatC* __restrict__ p_c_block) const diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 3a4d34faf0..ad4e9d2cdf 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -127,6 +127,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 } #if DEVICE_BACKEND_HIP + // TODO: this is not working correctly template __device__ void Run_asm(const FloatA* __restrict__ p_a_block, const FloatB* __restrict__ p_b_block, diff --git a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp index 8556281669..eab42ed07b 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp @@ -204,23 +204,38 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), __syncthreads()) { - // input: global mem to LDS +#if 1 blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - - // weight: global mem to LDS blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); +#else + Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; + Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; + + blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, + p_in_register_clipboard); + blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, + p_wei_register_clipboard); + + blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block); + blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block); +#endif __syncthreads(); - // a series of batched GEMM +#pragma unroll for(index_t y = 0; y < Y; ++y) { +#pragma unroll for(index_t x = 0; x < X; ++x) { - blockwise_batch_gemm.Run(p_wei_block + - wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), - p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), - p_out_thread); +#if 1 + blockwise_batch_gemm.Run +#else + blockwise_batch_gemm.Run_asm +#endif + (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), + p_out_thread); } } }