diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index ecf0993d59..62d15fd3a9 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -402,18 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); - //const float4* loc = (const float4 *)(p_a_block + src_index); + const float4* loc = (const float4 *)(p_a_block + src_index); float4* reg = (float4 *)(p_a_thread + dst_index); - //reg[0] = loc[0]; - //reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]; - asm volatile("\n \ - ds_read2_b64 %0, %2 offset1:1 \n \ - ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(reg[0]), "=v"(reg[1]) - : "v"(__to_local((void *)&p_a_block[src_index])) - ); + reg[0] = loc[0]; + reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]; + //asm volatile("\n \ + //ds_read2_b64 %0, %2 offset1:1 \n \ + //ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \ + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[0]), "=v"(reg[MPerThreadSubC/4]) + //: "v"(__to_local((void *)&p_a_block[src_index])) + //); } #endif @@ -459,16 +459,52 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 for(index_t k = 0; k < 1; ++k) { // M = 8 + const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, 0); for(index_t i = 0; i < 8; ++i) { // N = 8 - for(index_t j = 0; j < 8; ++j) + const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed + const index_t cindex = c_thread_mtx.Get1dIndex(i, 0); + //for(index_t j = 0; j < 8; ++j) { - const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed - const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, j); - const index_t cindex = c_thread_mtx.Get1dIndex(i, j); - p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex]; + //p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex]; + asm volatile("\n \ + v_mac_f32 %0, %8, %9 \n \ + v_mac_f32 %1, %8, %10 \n \ + v_mac_f32 %2, %8, %11 \n \ + v_mac_f32 %3, %8, %12 \n \ + v_mac_f32 %4, %8, %13 \n \ + v_mac_f32 %5, %8, %14 \n \ + v_mac_f32 %6, %8, %15 \n \ + v_mac_f32 %7, %8, %16 \n \ + " + : "=v"(p_c_thread[cindex + 0]), + "=v"(p_c_thread[cindex + 1]), + "=v"(p_c_thread[cindex + 2]), + "=v"(p_c_thread[cindex + 3]), + "=v"(p_c_thread[cindex + 4]), + "=v"(p_c_thread[cindex + 5]), + "=v"(p_c_thread[cindex + 6]), + "=v"(p_c_thread[cindex + 7]) + : "v"(p_a_thread[aindex]), + "v"(p_b_thread[bindex + 0]), + "v"(p_b_thread[bindex + 1]), + "v"(p_b_thread[bindex + 2]), + "v"(p_b_thread[bindex + 3]), + "v"(p_b_thread[bindex + 4]), + "v"(p_b_thread[bindex + 5]), + "v"(p_b_thread[bindex + 6]), + "v"(p_b_thread[bindex + 7]) + "0"(p_c_thread[cindex + 0]), + "1"(p_c_thread[cindex + 1]), + "2"(p_c_thread[cindex + 2]), + "3"(p_c_thread[cindex + 3]), + "4"(p_c_thread[cindex + 4]), + "5"(p_c_thread[cindex + 5]), + "6"(p_c_thread[cindex + 6]), + "7"(p_c_thread[cindex + 7]) + ); } } }