diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index 2f29ca9f18..b20d1eebbd 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -464,20 +464,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 const float4* b_loc = (const float4 *)(p_b_block + b_src_index); float4* reg = (float4 *)(p_a_thread + dst_index); - //reg[0] = a_loc[0]; - //reg[1] = a_loc[16]; - //reg[2] = b_loc[0]; - //reg[3] = b_loc[8]; - //s_waitcnt lgkmcnt(0) // 000000001398: BF8CC07F - asm volatile("\n \ - ds_read2_b64 %0, %4 offset1:1 \n \ - ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \ - ds_read2_b64 %2, %5 offset1:1 \n \ - ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \ - s_waitcnt lgkmcnt(0)" - : "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]) - : "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc))) - ); + reg[0] = a_loc[0]; + reg[1] = a_loc[16]; + reg[2] = b_loc[0]; + reg[3] = b_loc[8]; + //asm volatile("\n \ + //ds_read2_b64 %0, %4 offset1:1 \n \ + //ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \ + //ds_read2_b64 %2, %5 offset1:1 \n \ + //ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \ + //s_waitcnt lgkmcnt(0)" + //: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3]) + //: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc))) + //); #endif @@ -495,58 +494,219 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 p_c_thread, f_accum); #else - for(index_t k = 0; k < 1; ++k) - { - // M = 8 - const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, 0); - for(index_t i = 0; i < 8; ++i) - { - // N = 8 - const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed - const index_t cindex = c_thread_mtx.Get1dIndex(i, 0); - //for(index_t j = 0; j < 8; ++j) - { - - //p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex]; asm volatile("\n \ - v_mac_f32 %0, %8, %9 \n \ - v_mac_f32 %1, %8, %10 \n \ - v_mac_f32 %2, %8, %11 \n \ - v_mac_f32 %3, %8, %12 \n \ - v_mac_f32 %4, %8, %13 \n \ - v_mac_f32 %5, %8, %14 \n \ - v_mac_f32 %6, %8, %15 \n \ - v_mac_f32 %7, %8, %16 \n \ + v_mac_f32 %0, %64, %72 \n \ + v_mac_f32 %1, %64, %73 \n \ + v_mac_f32 %2, %64, %74 \n \ + v_mac_f32 %3, %64, %75 \n \ + v_mac_f32 %4, %64, %76 \n \ + v_mac_f32 %5, %64, %77 \n \ + v_mac_f32 %6, %64, %78 \n \ + v_mac_f32 %7, %64, %79 \n \ + v_mac_f32 %8, %65, %72 \n \ + v_mac_f32 %9, %65, %73 \n \ + v_mac_f32 %10, %65, %74 \n \ + v_mac_f32 %11, %65, %75 \n \ + v_mac_f32 %12, %65, %76 \n \ + v_mac_f32 %13, %65, %77 \n \ + v_mac_f32 %14, %65, %78 \n \ + v_mac_f32 %15, %65, %79 \n \ + v_mac_f32 %16, %66, %72 \n \ + v_mac_f32 %17, %66, %73 \n \ + v_mac_f32 %18, %66, %74 \n \ + v_mac_f32 %19, %66, %75 \n \ + v_mac_f32 %20, %66, %76 \n \ + v_mac_f32 %21, %66, %77 \n \ + v_mac_f32 %22, %66, %78 \n \ + v_mac_f32 %23, %66, %79 \n \ + v_mac_f32 %24, %67, %72 \n \ + v_mac_f32 %25, %67, %73 \n \ + v_mac_f32 %26, %67, %74 \n \ + v_mac_f32 %27, %67, %75 \n \ + v_mac_f32 %28, %67, %76 \n \ + v_mac_f32 %29, %67, %77 \n \ + v_mac_f32 %30, %67, %78 \n \ + v_mac_f32 %31, %67, %79 \n \ + v_mac_f32 %32, %68, %72 \n \ + v_mac_f32 %33, %68, %73 \n \ + v_mac_f32 %34, %68, %74 \n \ + v_mac_f32 %35, %68, %75 \n \ + v_mac_f32 %36, %68, %76 \n \ + v_mac_f32 %37, %68, %77 \n \ + v_mac_f32 %38, %68, %78 \n \ + v_mac_f32 %39, %68, %79 \n \ + v_mac_f32 %40, %69, %72 \n \ + v_mac_f32 %41, %69, %73 \n \ + v_mac_f32 %42, %69, %74 \n \ + v_mac_f32 %43, %69, %75 \n \ + v_mac_f32 %44, %69, %76 \n \ + v_mac_f32 %45, %69, %77 \n \ + v_mac_f32 %46, %69, %78 \n \ + v_mac_f32 %47, %69, %79 \n \ + v_mac_f32 %48, %70, %72 \n \ + v_mac_f32 %49, %70, %73 \n \ + v_mac_f32 %50, %70, %74 \n \ + v_mac_f32 %51, %70, %75 \n \ + v_mac_f32 %52, %70, %76 \n \ + v_mac_f32 %53, %70, %77 \n \ + v_mac_f32 %54, %70, %78 \n \ + v_mac_f32 %55, %70, %79 \n \ + v_mac_f32 %56, %71, %72 \n \ + v_mac_f32 %57, %71, %73 \n \ + v_mac_f32 %58, %71, %74 \n \ + v_mac_f32 %59, %71, %75 \n \ + v_mac_f32 %60, %71, %76 \n \ + v_mac_f32 %61, %71, %77 \n \ + v_mac_f32 %62, %71, %78 \n \ + v_mac_f32 %63, %71, %79 \n \ " - : "=v"(p_c_thread[cindex + 0]), - "=v"(p_c_thread[cindex + 1]), - "=v"(p_c_thread[cindex + 2]), - "=v"(p_c_thread[cindex + 3]), - "=v"(p_c_thread[cindex + 4]), - "=v"(p_c_thread[cindex + 5]), - "=v"(p_c_thread[cindex + 6]), - "=v"(p_c_thread[cindex + 7]) - : "v"(p_a_thread[aindex]), - "v"(p_b_thread[bindex + 0]), - "v"(p_b_thread[bindex + 1]), - "v"(p_b_thread[bindex + 2]), - "v"(p_b_thread[bindex + 3]), - "v"(p_b_thread[bindex + 4]), - "v"(p_b_thread[bindex + 5]), - "v"(p_b_thread[bindex + 6]), - "v"(p_b_thread[bindex + 7]), - "0"(p_c_thread[cindex + 0]), - "1"(p_c_thread[cindex + 1]), - "2"(p_c_thread[cindex + 2]), - "3"(p_c_thread[cindex + 3]), - "4"(p_c_thread[cindex + 4]), - "5"(p_c_thread[cindex + 5]), - "6"(p_c_thread[cindex + 6]), - "7"(p_c_thread[cindex + 7]) + : + "=v"(p_c_thread[0]), + "=v"(p_c_thread[1]), + "=v"(p_c_thread[2]), + "=v"(p_c_thread[3]), + "=v"(p_c_thread[4]), + "=v"(p_c_thread[5]), + "=v"(p_c_thread[6]), + "=v"(p_c_thread[7]), + "=v"(p_c_thread[8]), + "=v"(p_c_thread[9]), + "=v"(p_c_thread[10]), + "=v"(p_c_thread[11]), + "=v"(p_c_thread[12]), + "=v"(p_c_thread[13]), + "=v"(p_c_thread[14]), + "=v"(p_c_thread[15]), + "=v"(p_c_thread[16]), + "=v"(p_c_thread[17]), + "=v"(p_c_thread[18]), + "=v"(p_c_thread[19]), + "=v"(p_c_thread[20]), + "=v"(p_c_thread[21]), + "=v"(p_c_thread[22]), + "=v"(p_c_thread[23]), + "=v"(p_c_thread[24]), + "=v"(p_c_thread[25]), + "=v"(p_c_thread[26]), + "=v"(p_c_thread[27]), + "=v"(p_c_thread[28]), + "=v"(p_c_thread[29]), + "=v"(p_c_thread[30]), + "=v"(p_c_thread[31]), + "=v"(p_c_thread[32]), + "=v"(p_c_thread[33]), + "=v"(p_c_thread[34]), + "=v"(p_c_thread[35]), + "=v"(p_c_thread[36]), + "=v"(p_c_thread[37]), + "=v"(p_c_thread[38]), + "=v"(p_c_thread[39]), + "=v"(p_c_thread[40]), + "=v"(p_c_thread[41]), + "=v"(p_c_thread[42]), + "=v"(p_c_thread[43]), + "=v"(p_c_thread[44]), + "=v"(p_c_thread[45]), + "=v"(p_c_thread[46]), + "=v"(p_c_thread[47]), + "=v"(p_c_thread[48]), + "=v"(p_c_thread[49]), + "=v"(p_c_thread[50]), + "=v"(p_c_thread[51]), + "=v"(p_c_thread[52]), + "=v"(p_c_thread[53]), + "=v"(p_c_thread[54]), + "=v"(p_c_thread[55]), + "=v"(p_c_thread[56]), + "=v"(p_c_thread[57]), + "=v"(p_c_thread[58]), + "=v"(p_c_thread[59]), + "=v"(p_c_thread[60]), + "=v"(p_c_thread[61]), + "=v"(p_c_thread[62]), + "=v"(p_c_thread[63]) + : + "v"(p_a_thread[0]), + "v"(p_a_thread[1]), + "v"(p_a_thread[2]), + "v"(p_a_thread[3]), + "v"(p_a_thread[4]), + "v"(p_a_thread[5]), + "v"(p_a_thread[6]), + "v"(p_a_thread[7]), + "v"(p_b_thread[0]), + "v"(p_b_thread[1]), + "v"(p_b_thread[2]), + "v"(p_b_thread[3]), + "v"(p_b_thread[4]), + "v"(p_b_thread[5]), + "v"(p_b_thread[6]), + "v"(p_b_thread[7]), + "0"(p_c_thread[0]), + "1"(p_c_thread[1]), + "2"(p_c_thread[2]), + "3"(p_c_thread[3]), + "4"(p_c_thread[4]), + "5"(p_c_thread[5]), + "6"(p_c_thread[6]), + "7"(p_c_thread[7]), + "8"(p_c_thread[8]), + "9"(p_c_thread[9]), + "10"(p_c_thread[10]), + "11"(p_c_thread[11]), + "12"(p_c_thread[12]), + "13"(p_c_thread[13]), + "14"(p_c_thread[14]), + "15"(p_c_thread[15]), + "16"(p_c_thread[16]), + "17"(p_c_thread[17]), + "18"(p_c_thread[18]), + "19"(p_c_thread[19]), + "20"(p_c_thread[20]), + "21"(p_c_thread[21]), + "22"(p_c_thread[22]), + "23"(p_c_thread[23]), + "24"(p_c_thread[24]), + "25"(p_c_thread[25]), + "26"(p_c_thread[26]), + "27"(p_c_thread[27]), + "28"(p_c_thread[28]), + "29"(p_c_thread[29]), + "30"(p_c_thread[30]), + "31"(p_c_thread[31]), + "32"(p_c_thread[32]), + "33"(p_c_thread[33]), + "34"(p_c_thread[34]), + "35"(p_c_thread[35]), + "36"(p_c_thread[36]), + "37"(p_c_thread[37]), + "38"(p_c_thread[38]), + "39"(p_c_thread[39]), + "40"(p_c_thread[40]), + "41"(p_c_thread[41]), + "42"(p_c_thread[42]), + "43"(p_c_thread[43]), + "44"(p_c_thread[44]), + "45"(p_c_thread[45]), + "46"(p_c_thread[46]), + "47"(p_c_thread[47]), + "48"(p_c_thread[48]), + "49"(p_c_thread[49]), + "50"(p_c_thread[50]), + "51"(p_c_thread[51]), + "52"(p_c_thread[52]), + "53"(p_c_thread[53]), + "54"(p_c_thread[54]), + "55"(p_c_thread[55]), + "56"(p_c_thread[56]), + "57"(p_c_thread[57]), + "58"(p_c_thread[58]), + "59"(p_c_thread[59]), + "60"(p_c_thread[60]), + "61"(p_c_thread[61]), + "62"(p_c_thread[62]), + "63"(p_c_thread[63]) ); - } - } - } #endif } }