mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
in progress
This commit is contained in:
@@ -402,18 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
|
||||
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
|
||||
|
||||
//const float4* loc = (const float4 *)(p_a_block + src_index);
|
||||
const float4* loc = (const float4 *)(p_a_block + src_index);
|
||||
float4* reg = (float4 *)(p_a_thread + dst_index);
|
||||
|
||||
//reg[0] = loc[0];
|
||||
//reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
|
||||
asm volatile("\n \
|
||||
ds_read2_b64 %0, %2 offset1:1 \n \
|
||||
ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(reg[0]), "=v"(reg[1])
|
||||
: "v"(__to_local((void *)&p_a_block[src_index]))
|
||||
);
|
||||
reg[0] = loc[0];
|
||||
reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
|
||||
//asm volatile("\n \
|
||||
//ds_read2_b64 %0, %2 offset1:1 \n \
|
||||
//ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
|
||||
//s_waitcnt lgkmcnt(0)"
|
||||
//: "=v"(reg[0]), "=v"(reg[MPerThreadSubC/4])
|
||||
//: "v"(__to_local((void *)&p_a_block[src_index]))
|
||||
//);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -459,16 +459,52 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
for(index_t k = 0; k < 1; ++k)
|
||||
{
|
||||
// M = 8
|
||||
const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, 0);
|
||||
for(index_t i = 0; i < 8; ++i)
|
||||
{
|
||||
// N = 8
|
||||
for(index_t j = 0; j < 8; ++j)
|
||||
const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed
|
||||
const index_t cindex = c_thread_mtx.Get1dIndex(i, 0);
|
||||
//for(index_t j = 0; j < 8; ++j)
|
||||
{
|
||||
const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed
|
||||
const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, j);
|
||||
const index_t cindex = c_thread_mtx.Get1dIndex(i, j);
|
||||
|
||||
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
|
||||
//p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %8, %9 \n \
|
||||
v_mac_f32 %1, %8, %10 \n \
|
||||
v_mac_f32 %2, %8, %11 \n \
|
||||
v_mac_f32 %3, %8, %12 \n \
|
||||
v_mac_f32 %4, %8, %13 \n \
|
||||
v_mac_f32 %5, %8, %14 \n \
|
||||
v_mac_f32 %6, %8, %15 \n \
|
||||
v_mac_f32 %7, %8, %16 \n \
|
||||
"
|
||||
: "=v"(p_c_thread[cindex + 0]),
|
||||
"=v"(p_c_thread[cindex + 1]),
|
||||
"=v"(p_c_thread[cindex + 2]),
|
||||
"=v"(p_c_thread[cindex + 3]),
|
||||
"=v"(p_c_thread[cindex + 4]),
|
||||
"=v"(p_c_thread[cindex + 5]),
|
||||
"=v"(p_c_thread[cindex + 6]),
|
||||
"=v"(p_c_thread[cindex + 7])
|
||||
: "v"(p_a_thread[aindex]),
|
||||
"v"(p_b_thread[bindex + 0]),
|
||||
"v"(p_b_thread[bindex + 1]),
|
||||
"v"(p_b_thread[bindex + 2]),
|
||||
"v"(p_b_thread[bindex + 3]),
|
||||
"v"(p_b_thread[bindex + 4]),
|
||||
"v"(p_b_thread[bindex + 5]),
|
||||
"v"(p_b_thread[bindex + 6]),
|
||||
"v"(p_b_thread[bindex + 7])
|
||||
"0"(p_c_thread[cindex + 0]),
|
||||
"1"(p_c_thread[cindex + 1]),
|
||||
"2"(p_c_thread[cindex + 2]),
|
||||
"3"(p_c_thread[cindex + 3]),
|
||||
"4"(p_c_thread[cindex + 4]),
|
||||
"5"(p_c_thread[cindex + 5]),
|
||||
"6"(p_c_thread[cindex + 6]),
|
||||
"7"(p_c_thread[cindex + 7])
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user