From 0de4286a4f10cd8a72339cebea9835f869f86012 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 3 Apr 2019 15:19:40 -0500 Subject: [PATCH] increase depth of pipeline --- src/include/blockwise_gemm.hip.hpp | 29 +++++++++++++++++++++++------ src/include/threadwise_gemm.hip.hpp | 10 +++------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index c83e44b31b..7b1ed63702 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -384,9 +384,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 Float4* reg_c = (Float4*)(p_c_thread); void* a_loc = (void *)(p_a_block + mMyThreadOffsetA); void* b_loc = (void *)(p_b_block + mMyThreadOffsetB); -#pragma unroll // loop over k - for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) + int k_chunk = 2; +#pragma unroll + for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk) { #if 0 @@ -402,15 +403,31 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); #else - ds_read_b128(reg_a[0], a_loc, k_begin * 512); - ds_read_b128(reg_b[0], b_loc, k_begin * 256); - ds_read_b128(reg_b[1], b_loc, 128 + k_begin * 256); - ds_read_b128(reg_a[1], a_loc, 256 + k_begin * 512); + int k = k_begin; + ds_read_b128(reg_a[0], a_loc, k * 512); + ds_read_b128(reg_b[0], b_loc, k * 256); + ds_read_b128(reg_b[1], b_loc, 128 + k * 256); + ds_read_b128(reg_a[1], a_loc, 256 + k * 512); lgkmcnt(2); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); lgkmcnt(1); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); lgkmcnt(0); + for(int i = 0; i < k_chunk - 1; i++) + { + k = k + 1; + ds_read_b128(reg_a[0], a_loc, k * 512); + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + ds_read_b128(reg_b[0], b_loc, k * 256); + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); + ds_read_b128(reg_b[1], b_loc, 128 + k * 256); + ds_read_b128(reg_a[1], a_loc, 256 + k * 512); + lgkmcnt(2); + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + lgkmcnt(1); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + lgkmcnt(0); + } outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); #endif diff --git a/src/include/threadwise_gemm.hip.hpp b/src/include/threadwise_gemm.hip.hpp index f1af308440..c5a7de8049 100644 --- a/src/include/threadwise_gemm.hip.hpp +++ b/src/include/threadwise_gemm.hip.hpp @@ -73,20 +73,16 @@ __device__ void threadwise_gemm(MatrixA, for(index_t k = 0; k < K; ++k) { #if 1 - for(index_t i = 0; i < M; i+=4) + for(index_t i = 0; i < M; i++) { const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed - const Float4 *a_vec = (const Float4 *)&p_a_thread[aindex]; - for(index_t j = 0; j < N; j+=4) + for(index_t j = 0; j < N; j++) { const index_t bindex = b_mtx.Get1dIndex(k, j); const index_t cindex = c_mtx.Get1dIndex(i, j); - const Float4 *b_vec = (const Float4 *)&p_b_thread[bindex]; - Float4 *c_vec = (Float4 *)&p_c_thread[cindex]; - - outerProduct4x4(a_vec[0], b_vec[0], c_vec[0], c_vec[2], c_vec[4], c_vec[6]); + p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex]; } } #else