mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
clean code
This commit is contained in:
@@ -361,10 +361,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
|
||||
|
||||
@@ -377,66 +377,42 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
|
||||
// auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
|
||||
Float4* reg_a = (Float4*)(p_a_thread);
|
||||
Float4* reg_b = (Float4*)(p_b_thread);
|
||||
Float4* reg_c = (Float4*)(p_c_thread);
|
||||
void* a_loc = (void*)(p_a_block + mMyThreadOffsetA);
|
||||
void* b_loc = (void*)(p_b_block + mMyThreadOffsetB);
|
||||
// loop over k
|
||||
int k_chunk = K;
|
||||
//for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk)
|
||||
index_t k_begin = 0;
|
||||
|
||||
int lds_a_block_off = sizeof(Float) * M;
|
||||
int lds_b_block_off = sizeof(Float) * N;
|
||||
int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
|
||||
int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
|
||||
ds_read_b128(reg_a[0], a_loc, 0);
|
||||
ds_read_b128(reg_b[0], b_loc, 0);
|
||||
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
|
||||
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
|
||||
lgkmcnt(2);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
lgkmcnt(0);
|
||||
#pragma unroll
|
||||
for(int k_i = 1; k_i < K; k_i++)
|
||||
{
|
||||
|
||||
#if 0
|
||||
ds_read_b128(reg_a[0], a_loc, 0);
|
||||
ds_read_b128(reg_a[1], a_loc, 256);
|
||||
ds_read_b128(reg_b[0], b_loc, 0);
|
||||
ds_read_b128(reg_b[1], b_loc, 128);
|
||||
|
||||
lgkmcnt(0);
|
||||
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
#else
|
||||
int k = k_begin;
|
||||
int lds_a_block_off = sizeof(Float) * M;
|
||||
int lds_b_block_off = sizeof(Float) * N;
|
||||
int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
|
||||
int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
|
||||
ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off);
|
||||
ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off);
|
||||
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off);
|
||||
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off);
|
||||
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
|
||||
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
|
||||
lgkmcnt(2);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
lgkmcnt(0);
|
||||
#pragma unroll
|
||||
for(int i = 0; i < k_chunk - 1; i++)
|
||||
{
|
||||
k = k + 1;
|
||||
ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off);
|
||||
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off);
|
||||
lgkmcnt(2);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
lgkmcnt(0);
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
#endif
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC, class Accumulator>
|
||||
|
||||
Reference in New Issue
Block a user