mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
in progress
This commit is contained in:
@@ -368,8 +368,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
|
||||
|
||||
FloatA *p_a_thread = p_thread;
|
||||
FloatB *p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
@@ -381,6 +383,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
// loop over k
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
#if 0
|
||||
// copy A-sub to form A
|
||||
#if 0
|
||||
#pragma unroll
|
||||
@@ -406,13 +409,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
float4* reg = (float4 *)(p_a_thread + dst_index);
|
||||
|
||||
reg[0] = loc[0];
|
||||
reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
|
||||
reg[1] = loc[16];
|
||||
//reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
|
||||
//asm volatile("\n \
|
||||
//ds_read2_b64 %0, %2 offset1:1 \n \
|
||||
//ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
|
||||
//ds_read2_b64 %1, %2 offset0:32 offset1:33 \n \
|
||||
//s_waitcnt lgkmcnt(0)"
|
||||
//: "=v"(reg[0]), "=v"(reg[MPerThreadSubC/4])
|
||||
//: "v"(__to_local((void *)&p_a_block[src_index]))
|
||||
//: "=v"(reg[0]), "=v"(reg[1])
|
||||
//: "v"(__to_local((void *)(loc)))
|
||||
//);
|
||||
}
|
||||
#endif
|
||||
@@ -439,8 +443,43 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
float4* reg = (float4 *)(p_b_thread + dst_index);
|
||||
|
||||
reg[0] = loc[0];
|
||||
reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4];
|
||||
reg[1] = loc[8];
|
||||
//reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4];
|
||||
//asm volatile("\n \
|
||||
//ds_read2_b64 %0, %2 offset1:1 \n \
|
||||
//ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
|
||||
//s_waitcnt lgkmcnt(0)"
|
||||
//: "=v"(reg[0]), "=v"(reg[1])
|
||||
//: "v"(__to_local((void *)(loc)))
|
||||
//);
|
||||
}
|
||||
#endif
|
||||
|
||||
#else
|
||||
auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
|
||||
auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
|
||||
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
|
||||
|
||||
const float4* a_loc = (const float4 *)(p_a_block + a_src_index);
|
||||
const float4* b_loc = (const float4 *)(p_b_block + b_src_index);
|
||||
float4* reg = (float4 *)(p_a_thread + dst_index);
|
||||
|
||||
//reg[0] = a_loc[0];
|
||||
//reg[1] = a_loc[16];
|
||||
//reg[2] = b_loc[0];
|
||||
//reg[3] = b_loc[8];
|
||||
//s_waitcnt lgkmcnt(0) // 000000001398: BF8CC07F
|
||||
asm volatile("\n \
|
||||
ds_read2_b64 %0, %4 offset1:1 \n \
|
||||
ds_read2_b64 %1, %4 offset0:32 offset1:33 \n \
|
||||
ds_read2_b64 %2, %5 offset1:1 \n \
|
||||
ds_read2_b64 %3, %5 offset0:16 offset1:17 \n \
|
||||
s_waitcnt lgkmcnt(0)"
|
||||
: "=v"(reg[0]), "=v"(reg[1]), "=v"(reg[2]), "=v"(reg[3])
|
||||
: "v"(__to_local((void *)(a_loc))), "v"(__to_local((void *)(b_loc)))
|
||||
);
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
// C = A * B
|
||||
@@ -495,7 +534,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
"v"(p_b_thread[bindex + 4]),
|
||||
"v"(p_b_thread[bindex + 5]),
|
||||
"v"(p_b_thread[bindex + 6]),
|
||||
"v"(p_b_thread[bindex + 7])
|
||||
"v"(p_b_thread[bindex + 7]),
|
||||
"0"(p_c_thread[cindex + 0]),
|
||||
"1"(p_c_thread[cindex + 1]),
|
||||
"2"(p_c_thread[cindex + 2]),
|
||||
|
||||
Reference in New Issue
Block a user