mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
unroll even-odd loop
This commit is contained in:
@@ -311,7 +311,10 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1024) * 1024 * 1024 * 1024) / (time / 1000));
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
|
||||
@@ -386,7 +386,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
void* b_loc = (void*)(p_b_block + mMyThreadOffsetB);
|
||||
// loop over k
|
||||
int k_chunk = K;
|
||||
//for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk)
|
||||
// for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk)
|
||||
index_t k_begin = 0;
|
||||
{
|
||||
|
||||
|
||||
@@ -69,3 +69,27 @@ __host__ __device__ constexpr auto get_convolution_with_padding_output_default_4
|
||||
|
||||
return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
|
||||
}
|
||||
|
||||
template <class InDesc, class WeiDesc, class OutDesc>
|
||||
__host__ __device__ constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
|
||||
{
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t N = out_desc.GetLength(I0);
|
||||
constexpr index_t K = out_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t C = wei_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_desc.GetLength(I3);
|
||||
|
||||
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
|
||||
return in_cb_block_desc.GetElementSpace(Number<max_align>{});
|
||||
}
|
||||
|
||||
@@ -213,26 +213,18 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
#endif
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_element_size =
|
||||
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
constexpr index_t max_align =
|
||||
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
|
||||
constexpr index_t wei_block_element_size =
|
||||
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
constexpr index_t in_block_element_space =
|
||||
in_cb_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
constexpr index_t wei_block_element_space =
|
||||
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
// LDS double buffer
|
||||
__shared__ Float
|
||||
p_in_block_0[max_align * ((in_block_element_size + max_align - 1) / max_align)];
|
||||
__shared__ Float
|
||||
p_wei_block_0[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
|
||||
|
||||
__shared__ Float
|
||||
p_in_block_1[max_align * ((in_block_element_size + max_align - 1) / max_align)];
|
||||
__shared__ Float
|
||||
p_wei_block_1[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
|
||||
__shared__ Float p_in_block_double[2 * in_block_element_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_element_space];
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
|
||||
@@ -254,62 +246,122 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
global_load(tmp_in, glb_in_p);
|
||||
global_load(tmp_wei, glb_wei_p);
|
||||
|
||||
Float4* loc_in_p = (Float4*)(p_in_block_0 + blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4*)(p_wei_block_0 + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
Float4* loc_in_p = (Float4*)(p_in_block_double + blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4*)(p_wei_block_double + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
|
||||
vmcnt(0);
|
||||
ds_write_b128(tmp_in, loc_in_p);
|
||||
ds_write_b128(tmp_wei, loc_wei_p);
|
||||
#endif
|
||||
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
|
||||
|
||||
bool even_loop = true;
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
|
||||
even_loop = !even_loop)
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
|
||||
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_next = even_loop ? p_in_block_1 : p_in_block_0;
|
||||
Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_element_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_element_space;
|
||||
|
||||
__syncthreads();
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_element_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_element_space : p_wei_block_double;
|
||||
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
|
||||
// load next data
|
||||
#if 0
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
#elif 1
|
||||
Float4 tmp_in, tmp_wei;
|
||||
Float4* glb_in_p =
|
||||
(Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
|
||||
Float4* glb_wei_p =
|
||||
(Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
global_load(tmp_in, glb_in_p);
|
||||
global_load(tmp_wei, glb_wei_p);
|
||||
#endif
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
#if 0
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
#elif 1
|
||||
Float4* loc_in_p =
|
||||
(Float4*)(p_in_block_next + blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* loc_wei_p =
|
||||
(Float4*)(p_wei_block_next + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
|
||||
vmcnt(0);
|
||||
ds_write_b128(tmp_in, loc_in_p);
|
||||
ds_write_b128(tmp_wei, loc_wei_p);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
if(C % 2 == 0)
|
||||
{
|
||||
// even
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
|
||||
Float4 tmp_in, tmp_wei;
|
||||
Float4* glb_in_p =
|
||||
(Float4*)(p_in_global_block_offset + blockwise_in_copy.mSrcMyThreadOffset);
|
||||
Float4* glb_wei_p =
|
||||
(Float4*)(p_wei_global_block_offset + blockwise_wei_copy.mSrcMyThreadOffset);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
global_load(tmp_in, glb_in_p);
|
||||
global_load(tmp_wei, glb_wei_p);
|
||||
#endif
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
@@ -317,37 +369,28 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
#if 0
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
(p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_double + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
#elif 1
|
||||
Float4* loc_in_p = (Float4*)(p_in_block_next + blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4*)(p_wei_block_next + blockwise_wei_copy.mDstMyThreadOffset);
|
||||
Float4* loc_in_p = (Float4*)(p_in_block_double + in_block_element_space +
|
||||
blockwise_in_copy.mDstMyThreadOffset);
|
||||
Float4* loc_wei_p = (Float4*)(p_wei_block_double + wei_block_element_space +
|
||||
blockwise_wei_copy.mDstMyThreadOffset);
|
||||
|
||||
vmcnt(0);
|
||||
ds_write_b128(tmp_in, loc_in_p);
|
||||
ds_write_b128(tmp_wei, loc_wei_p);
|
||||
#endif
|
||||
}
|
||||
|
||||
// last computation
|
||||
{
|
||||
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
|
||||
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
|
||||
|
||||
// odd
|
||||
__syncthreads();
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
@@ -362,13 +405,19 @@ struct gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
(p_wei_block_double + in_block_element_space +
|
||||
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_double + wei_block_element_space + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// not implemented
|
||||
assert(false);
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
|
||||
Reference in New Issue
Block a user