mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
add asm into lds_double_buffer version
This commit is contained in:
@@ -34,9 +34,8 @@ template <index_t GridSize,
|
||||
index_t WeiBlockCopyThreadPerDim1,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead>
|
||||
class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
{
|
||||
public:
|
||||
__host__ __device__ constexpr index_t GetInputBlockElementSpace() const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -97,7 +96,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
return wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t GetSharedMemoryUsage() const
|
||||
__host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const
|
||||
{
|
||||
|
||||
return (GetInputBlockElementSpace() + GetWeightBlockElementSpace()) * sizeof(Float);
|
||||
@@ -300,22 +299,38 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
// load data
|
||||
//blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
//blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
// load data
|
||||
#if 0
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
#elif 0
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
|
||||
#elif 1
|
||||
Float4 tmp_in, tmp_wei;
|
||||
Float4* glb_in_p = (Float4 *)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
|
||||
Float4* loc_in_p = (Float4 *)(p_in_block + blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* glb_in_p =
|
||||
(Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
|
||||
Float4* loc_in_p = (Float4*)(p_in_block + blockwise_in_copy.mDstMyThreadOffset);
|
||||
|
||||
Float4* glb_wei_p = (Float4 *)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4 *)(p_wei_block + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
Float4* glb_wei_p =
|
||||
(Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4*)(p_wei_block + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
|
||||
global_load(tmp_in, glb_in_p);
|
||||
global_load(tmp_wei, glb_wei_p);
|
||||
vmcnt(0);
|
||||
ds_write_b128(tmp_in, loc_in_p);
|
||||
ds_write_b128(tmp_wei, loc_wei_p);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -34,9 +34,10 @@ template <index_t GridSize,
|
||||
index_t WeiBlockCopyThreadPerDim1,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead>
|
||||
class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
public:
|
||||
__host__ __device__ constexpr index_t GetDynamicSharedMemoryUsage() const { return 0; }
|
||||
|
||||
__global__ static void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
@@ -239,9 +240,27 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
// preload data into LDS
|
||||
// preload data into LDS
|
||||
#if 0
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);
|
||||
#else
|
||||
Float4 tmp_in, tmp_wei;
|
||||
Float4* glb_in_p =
|
||||
(Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
|
||||
Float4* glb_wei_p =
|
||||
(Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
|
||||
|
||||
global_load(tmp_in, glb_in_p);
|
||||
global_load(tmp_wei, glb_wei_p);
|
||||
|
||||
Float4* loc_in_p = (Float4*)(p_in_block_0 + blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4*)(p_wei_block_0 + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
|
||||
vmcnt(0);
|
||||
ds_write_b128(tmp_in, loc_in_p);
|
||||
ds_write_b128(tmp_wei, loc_wei_p);
|
||||
#endif
|
||||
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
@@ -270,9 +289,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
// load next data
|
||||
#if 0
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
|
||||
#elif 1
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
@@ -281,6 +297,15 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
#elif 1
|
||||
Float4 tmp_in, tmp_wei;
|
||||
Float4* glb_in_p =
|
||||
(Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
|
||||
Float4* glb_wei_p =
|
||||
(Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
|
||||
|
||||
global_load(tmp_in, glb_in_p);
|
||||
global_load(tmp_wei, glb_wei_p);
|
||||
#endif
|
||||
|
||||
// compute on current data
|
||||
@@ -290,22 +315,31 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
#if 1
|
||||
#if 0
|
||||
blockwise_gemm.Run
|
||||
#else
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
#elif 1
|
||||
Float4* loc_in_p = (Float4*)(p_in_block_next + blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4*)(p_wei_block_next + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
|
||||
vmcnt(0);
|
||||
ds_write_b128(tmp_in, loc_in_p);
|
||||
ds_write_b128(tmp_wei, loc_wei_p);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -321,15 +355,17 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
#if 1
|
||||
#if 0
|
||||
blockwise_gemm.Run
|
||||
#else
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,28 +4,34 @@ typedef float Float4 __attribute__((ext_vector_type(4)));
|
||||
|
||||
extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
|
||||
|
||||
inline __device__ void vmcnt(int cnt) {
|
||||
if(cnt == 0) {
|
||||
asm volatile ("\n \
|
||||
inline __device__ void vmcnt(int cnt)
|
||||
{
|
||||
if(cnt == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
"::);
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 1) {
|
||||
asm volatile ("\n \
|
||||
else if(cnt == 1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(1) \n \
|
||||
"::);
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 2) {
|
||||
asm volatile ("\n \
|
||||
else if(cnt == 2)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(2) \n \
|
||||
"::);
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 4) {
|
||||
asm volatile ("\n \
|
||||
else if(cnt == 4)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(2) \n \
|
||||
"::);
|
||||
" ::);
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
@@ -397,13 +403,13 @@ inline __device__ void ds_read_b128(Float4& r, void* lds, int offset = 0)
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void global_load(Float4 &r, Float4* ptr) {
|
||||
asm volatile("\n \
|
||||
inline __device__ void global_load(Float4& r, Float4* ptr)
|
||||
{
|
||||
asm volatile("\n \
|
||||
global_load_dwordx4 %0, %1, off \n \
|
||||
"
|
||||
:"=v"(r)
|
||||
:"v"(ptr)
|
||||
);
|
||||
: "=v"(r)
|
||||
: "v"(ptr));
|
||||
}
|
||||
|
||||
inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0)
|
||||
@@ -411,8 +417,6 @@ inline __device__ void ds_write_b128(Float4& r, void* lds, int offset = 0)
|
||||
asm volatile("\n \
|
||||
ds_write_b128 %0, %1 \n \
|
||||
"
|
||||
:
|
||||
: "v"(__to_local(lds)), "v"(r)
|
||||
);
|
||||
:
|
||||
: "v"(__to_local(lds)), "v"(r));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user