mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
enabled ds_read_b128 and ds_write_b128 on hip c++
This commit is contained in:
@@ -132,10 +132,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
{
|
||||
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
|
||||
is_same<FloatC, float>::value,
|
||||
"Run_asm only deal with float\n");
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
@@ -164,56 +160,48 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
|
||||
is_same<FloatC, float>::value,
|
||||
"Run_asm only deal with float\n");
|
||||
|
||||
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
|
||||
MPerThread == 8 && NPerThread == 8,
|
||||
"Run_asm cannot deal with this GEMM shape yet\n");
|
||||
|
||||
using Float4 = vector_type<float, 4>::MemoryType;
|
||||
|
||||
float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
|
||||
|
||||
FloatA* p_a_thread = p_thread;
|
||||
FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
Float4* reg_a = (Float4*)(p_a_thread);
|
||||
Float4* reg_b = (Float4*)(p_b_thread);
|
||||
Float4* reg_c = (Float4*)(p_c_thread);
|
||||
void* a_loc = (void*)(p_a_block + mMyThreadOffsetA);
|
||||
void* b_loc = (void*)(p_b_block + mMyThreadOffsetB);
|
||||
|
||||
int lds_a_block_off = sizeof(Float) * M;
|
||||
int lds_b_block_off = sizeof(Float) * N;
|
||||
int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
|
||||
int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
|
||||
ds_read_b128(reg_a[0], a_loc, 0);
|
||||
ds_read_b128(reg_b[0], b_loc, 0);
|
||||
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
|
||||
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
|
||||
lgkmcnt(2);
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
|
||||
reg_b[1] =
|
||||
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
|
||||
reg_a[1] =
|
||||
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
lgkmcnt(0);
|
||||
#pragma unroll
|
||||
for(int k_i = 1; k_i < K; k_i++)
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
|
||||
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
|
||||
lgkmcnt(2);
|
||||
reg_b[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
|
||||
reg_a[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
lgkmcnt(0);
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
|
||||
@@ -213,17 +213,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
#if 1
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_double);
|
||||
#else
|
||||
vmcnt(0);
|
||||
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
|
||||
p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
|
||||
p_wei_block_double);
|
||||
#endif
|
||||
}
|
||||
|
||||
// register
|
||||
@@ -261,7 +253,6 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
@@ -271,31 +262,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread);
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
#else
|
||||
vmcnt(0);
|
||||
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,32 +303,23 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_double + y * Wi + x,
|
||||
p_out_thread);
|
||||
(p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_double + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_double + in_block_space);
|
||||
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_double + wei_block_space);
|
||||
#else
|
||||
vmcnt(0);
|
||||
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard,
|
||||
p_wei_block_double + wei_block_space);
|
||||
#endif
|
||||
|
||||
// odd
|
||||
__syncthreads();
|
||||
@@ -354,17 +328,17 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_double + wei_block_space +
|
||||
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_double + in_block_space + y * Wi + x,
|
||||
p_out_thread);
|
||||
(p_wei_block_double + wei_block_space +
|
||||
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_double + in_block_space + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user