diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 62d15fd3a9..2f29ca9f18 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -368,8 +368,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( Number{}, Number{}, Number{}); - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()]; + + FloatA *p_a_thread = p_thread; + FloatB *p_b_thread = p_thread + a_thread_mtx.GetElementSpace(); constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; @@ -381,6 +383,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 // loop over k for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) { +#if 0 // copy A-sub to form A #if 0 #pragma unroll @@ -406,13 +409,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 float4* reg = (float4 *)(p_a_thread + dst_index); reg[0] = loc[0]; - reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]; + reg[1] = loc[16]; + //reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]; //asm volatile("\n \ //ds_read2_b64 %0, %2 offset1:1 \n \ - //ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \ + //ds_read2_b64 %1, %2 offset0:32 offset1:33 \n \ //s_waitcnt lgkmcnt(0)" - //: "=v"(reg[0]), "=v"(reg[MPerThreadSubC/4]) - //: "v"(__to_local((void *)&p_a_block[src_index])) + //: "=v"(reg[0]), "=v"(reg[1]) + //: "v"(__to_local((void *)(loc))) //); } #endif @@ -439,8 +443,43 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 float4* reg = (float4 *)(p_b_thread + dst_index); reg[0] = loc[0]; - reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4]; + reg[1] = loc[8]; + //reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4]; + //asm volatile("\n \ + //ds_read2_b64 %0, %2 offset1:1 \n \ + //ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \ + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[0]), "=v"(reg[1]) + //: "v"(__to_local((void *)(loc))) + //); } +#endif + +#else + auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; + auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; + auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); + + const float4* a_loc = (const float4 *)(p_a_block + a_src_index); + const float4* b_loc = (const float4 *)(p_b_block + b_src_index); + float4* reg = (float4 *)(p_a_thread + dst_index); + + //reg[0] = a_loc[0]; + //reg[1] = a_loc[16]; + //reg[2] = b_loc[0]; + //reg[3] = b_loc[8]; + //s_waitcnt lgkmcnt(0) // 000000001398: BF8CC07F + asm volatile("\n \ + ds_read2_b64 %0, %4 offset1:1 \n \ + ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \ + ds_read2_b64 %2, %5 offset1:1 \n \ + ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \ + s_waitcnt lgkmcnt(0)" + : "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]) + : "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc))) + ); + + #endif // C = A * B @@ -495,7 +534,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 "v"(p_b_thread[bindex + 4]), "v"(p_b_thread[bindex + 5]), "v"(p_b_thread[bindex + 6]), - "v"(p_b_thread[bindex + 7]) + "v"(p_b_thread[bindex + 7]), "0"(p_c_thread[cindex + 0]), "1"(p_c_thread[cindex + 1]), "2"(p_c_thread[cindex + 2]),