diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index b20d1eebbd..95998cf040 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -458,16 +458,44 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 #else auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; - auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); const float4* a_loc = (const float4 *)(p_a_block + a_src_index); const float4* b_loc = (const float4 *)(p_b_block + b_src_index); - float4* reg = (float4 *)(p_a_thread + dst_index); + float4* reg = (float4 *)(p_thread); reg[0] = a_loc[0]; reg[1] = a_loc[16]; reg[2] = b_loc[0]; reg[3] = b_loc[8]; + + //asm volatile("\n \ + //ds_read2_b64 %0, %1 offset1:1 \n \ + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[0]) + //: "v"(__to_local((void *)(a_loc))) + //); + + //asm volatile("\n \ + //ds_read2_b64 %0, %1 offset1:1 \n \ + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[1]) + //: "v"(__to_local((void *)(a_loc + 16))) + //); + + //asm volatile("\n \ + //ds_read2_b64 %0, %1 offset1:1 \n \ + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[2]) + //: "v"(__to_local((void *)(b_loc))) + //); + + //asm volatile("\n \ + //ds_read2_b64 %0, %1 offset1:1 \n \ + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[3]) + //: "v"(__to_local((void *)(b_loc + 8))) + //); + //asm volatile("\n \ //ds_read2_b64 %0, %4 offset1:1 \n \ //ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \ @@ -478,6 +506,49 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 //: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc))) //); + //asm volatile("\n \ + //ds_read_b32 %0, %16 \n \ + //ds_read_b32 %1, %16 offset:1\n \ + //ds_read_b32 %2, %16 offset:2\n \ + //ds_read_b32 %3, %16 offset:3\n \ + //ds_read_b32 %4, %17 \n \ + //ds_read_b32 %5, %17 offset:1\n \ + //ds_read_b32 %6, %17 offset:2\n \ + //ds_read_b32 %7, %17 offset:3\n \ + //ds_read_b32 %8, %18 \n \ + //ds_read_b32 %9, %18 offset:1\n \ + //ds_read_b32 %10, %18 offset:2\n \ + //ds_read_b32 %11, %18 offset:3\n \ + //ds_read_b32 %12, %19 \n \ + //ds_read_b32 %13, %19 offset:1\n \ + //ds_read_b32 %14, %19 offset:2\n \ + //ds_read_b32 %15, %19 offset:3\n \ + //s_waitcnt lgkmcnt(0)" + //: + //"=v"(p_a_thread[0]), + //"=v"(p_a_thread[1]), + //"=v"(p_a_thread[2]), + //"=v"(p_a_thread[3]), + //"=v"(p_a_thread[4]), + //"=v"(p_a_thread[5]), + //"=v"(p_a_thread[6]), + //"=v"(p_a_thread[7]), + //"=v"(p_b_thread[0]), + //"=v"(p_b_thread[1]), + //"=v"(p_b_thread[2]), + //"=v"(p_b_thread[3]), + //"=v"(p_b_thread[4]), + //"=v"(p_b_thread[5]), + //"=v"(p_b_thread[6]), + //"=v"(p_b_thread[7]) + //: + //"v"(__to_local((void *)(&p_a_block[0]))), + //"v"(__to_local((void *)(&p_a_block[64]))), + //"v"(__to_local((void *)(&p_b_block[0]))), + //"v"(__to_local((void *)(&p_b_block[32]))) + //); + + #endif