From 2058bec8cfe6c006409ec3d65b67229ce1c2e6f7 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 28 Mar 2019 18:47:32 -0500 Subject: [PATCH] fused functions --- src/include/blockwise_gemm.hip.hpp | 50 ++++++++++++++++++- ..._gemm_convolution_2_chwn_cyxk_khwn.hip.hpp | 2 +- src/include/threadwise_gemm.hip.hpp | 7 ++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index f7cb637d4e..c3a17634ec 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -379,8 +379,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 // loop over k for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { -#pragma unroll // copy A-sub to form A +#if 0 +#pragma unroll + // MRepeat = 2 for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { threadwise_matrix_copy( @@ -391,9 +393,22 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), a_thread_sub_mtx.GetLengths()); } +#else + { + auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; + auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); -#pragma unroll + const float4* loc = (const float4 *)(p_a_block + src_index); + float4* reg = (float4 *)(p_a_thread + dst_index); + + reg[0] = loc[0]; + reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]; + } +#endif + +#if 0 // copy B-sub to form B +#pragma unroll for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { threadwise_matrix_copy( @@ -404,8 +419,21 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), b_thread_sub_mtx.GetLengths()); } +#else + { + auto src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; + auto dst_index = b_thread_sub_mtx.Get1dIndex(0, 0); + + const float4* loc = (const float4 *)(p_b_block + src_index); + float4* reg = (float4 *)(p_b_thread + dst_index); + + reg[0] = loc[0]; + reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4]; + } +#endif // C = A * B +#if 0 threadwise_gemm(a_thread_mtx, True, p_a_thread, @@ -416,6 +444,24 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 False, p_c_thread, f_accum); +#else + for(index_t k = 0; k < 1; ++k) + { + // M = 8 + for(index_t i = 0; i < 8; ++i) + { + // N = 8 + for(index_t j = 0; j < 8; ++j) + { + const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed + const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, j); + const index_t cindex = c_thread_mtx.Get1dIndex(i, j); + + p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex]; + } + } + } +#endif } } diff --git a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp index 08aa8f90f5..60ab9a919d 100644 --- a/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp @@ -236,7 +236,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric for(index_t x = 0; x < X; ++x) { auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; -#if 0 +#if 1 blockwise_gemm.Run #elif 0 blockwise_gemm.Run_asm diff --git a/src/include/threadwise_gemm.hip.hpp b/src/include/threadwise_gemm.hip.hpp index 8cf2404c63..51a2c8c62c 100644 --- a/src/include/threadwise_gemm.hip.hpp +++ b/src/include/threadwise_gemm.hip.hpp @@ -10,9 +10,11 @@ __device__ void threadwise_matrix_copy(SrcMatrix, constexpr auto src_mtx = SrcMatrix{}; constexpr auto dst_mtx = DstMatrix{}; -#if 0 +#if 1 + //NRow = 1 for(index_t i = 0; i < NRow; ++i) { + //NCol = 4 for(index_t j = 0; j < NCol; ++j) { const index_t src_index = src_mtx.Get1dIndex(i, j); @@ -76,10 +78,13 @@ __device__ void threadwise_gemm(MatrixA, constexpr index_t N = c_mtx.NCol(); constexpr index_t K = a_mtx.NRow(); // A is transposed + // K = 1 for(index_t k = 0; k < K; ++k) { + // M = 8 for(index_t i = 0; i < M; ++i) { + // N = 8 for(index_t j = 0; j < N; ++j) { const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed