mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
make LDS double buffer work; 1x1 conv now hits 80% of peak
This commit is contained in:
@@ -614,7 +614,7 @@ int main()
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(S == 3 && R == 3)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
|
||||
@@ -128,7 +128,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
|
||||
constexpr unsigned BlockSize = 64;
|
||||
#elif 1
|
||||
// 1x1, 28x28, 128 threads
|
||||
// 1x1, 28x28, 128 threads, no lds-double-buffer
|
||||
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 128;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
@@ -215,37 +216,37 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
|
||||
#else
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_cnhw_desc),
|
||||
decltype(wei_csrk_desc),
|
||||
decltype(out_knhw_desc),
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_cnhw_desc),
|
||||
decltype(wei_csrk_desc),
|
||||
decltype(out_knhw_desc),
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
|
||||
wei_csrk_desc,
|
||||
|
||||
@@ -512,4 +512,196 @@ struct Blockwise2dTensorCopy3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
// Returns the number of Float elements one thread needs in the clipboard
// buffer passed to RunLoadRegisterClipboard / RunStoreRegisterClipboard.
//
// Each thread copies DataPerRead elements per row it handles, and it handles
// at most ceil(L0 / thread_per_d0) rows (the full loop iterations plus one
// optional tail row), so that product is the exact size required.
__device__ constexpr unsigned GetRegisterClipboardSize() const
{
    static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr unsigned L0 = CopyLengths{}.Get(I0);
    constexpr unsigned L1 = CopyLengths{}.Get(I1);

    // threads needed to cover one row of L1 elements, DataPerRead elements each
    constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
    // number of rows the block covers in one pass
    constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;

    // The ceil-division must be parenthesized: without the parentheses the
    // multiplication by DataPerRead is applied before the integer division,
    // which rounds to a larger size than the copy routines ever index
    // (e.g. L0 = 8, thread_per_d0 = 4, DataPerRead = 4 gave 11 instead of 8).
    constexpr unsigned rows_per_thread = (L0 + thread_per_d0 - 1) / thread_per_d0;

    return DataPerRead * rows_per_thread;
}
|
||||
|
||||
// Reads this thread's share of the source tile from p_src into p_clipboard.
//
// p_src       : base pointer of the source; the thread's own starting element
//               is p_src + mSrcMyThreadOffset (the offset is set outside this
//               function).
// p_clipboard : per-thread buffer holding at least GetRegisterClipboardSize()
//               Float elements; row i of this thread's share is written at
//               p_clipboard + i * DataPerRead.
//
// Only DataPerRead of 1, 2 or 4 is supported; any other value hits
// assert(false). For DataPerRead of 2 or 4 the accesses are done through
// float2 / float4 pointers.
// NOTE(review): the float2 / float4 accesses presumably require both the
// source address and p_clipboard to be suitably aligned - confirm at callers.
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                         Float* p_clipboard) const
{
    static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

    using Float2 = float2;
    using Float4 = float4;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr unsigned L0 = CopyLengths{}.Get(I0);
    constexpr unsigned L1 = CopyLengths{}.Get(I1);

    // threads needed to cover one row of L1 elements, DataPerRead elements each
    constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
    // number of rows the block covers in one pass
    constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;

    constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;

    // threads beyond the thread_per_d0 x thread_per_d1 layout take no part
    if(BlockSize > num_active_thread)
    {
        if(get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }
    }

    // number of full passes over dimension 0
    constexpr unsigned nloop_d0 = L0 / thread_per_d0;

    // source distance between the rows a thread handles in consecutive passes
    constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;

    for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
    {
        if(DataPerRead == 1)
        {
            p_clipboard[iloop] = p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
        }
        else if(DataPerRead == 2)
        {
            *(reinterpret_cast<Float2*>(p_clipboard + iloop * 2)) =
                *(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
                                                  iloop * src_loop_stride));
        }
        else if(DataPerRead == 4)
        {
            *(reinterpret_cast<Float4*>(p_clipboard + iloop * 4)) =
                *(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
                                                  iloop * src_loop_stride));
        }
        else
        {
            assert(false);
        }
    }

    // rows left over when L0 is not a multiple of thread_per_d0
    constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

    if(has_tail_d0)
    {
        constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;

        // only the threads mapped onto the remaining tail_d0 rows copy again
        if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
        {
            if(DataPerRead == 1)
            {
                p_clipboard[nloop_d0] = p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
            }
            else if(DataPerRead == 2)
            {
                *(reinterpret_cast<Float2*>(p_clipboard + nloop_d0 * 2)) =
                    *(reinterpret_cast<const Float2*>(p_src + mSrcMyThreadOffset +
                                                      nloop_d0 * src_loop_stride));
            }
            else if(DataPerRead == 4)
            {
                *(reinterpret_cast<Float4*>(p_clipboard + nloop_d0 * 4)) =
                    *(reinterpret_cast<const Float4*>(p_src + mSrcMyThreadOffset +
                                                      nloop_d0 * src_loop_stride));
            }
            else
            {
                assert(false);
            }
        }
    }
}
|
||||
|
||||
// Writes this thread's share of the tile from p_clipboard out to p_dst; the
// counterpart of RunLoadRegisterClipboard, using the same clipboard layout.
//
// p_clipboard : per-thread buffer filled by RunLoadRegisterClipboard; row i of
//               this thread's share is read from p_clipboard + i * DataPerRead.
// p_dst       : base pointer of the destination; the thread's own starting
//               element is p_dst + mDstMyThreadOffset (the offset is set
//               outside this function).
//
// Only DataPerRead of 1, 2 or 4 is supported; any other value hits
// assert(false). For DataPerRead of 2 or 4 the accesses are done through
// float2 / float4 pointers.
// NOTE(review): the float2 / float4 accesses presumably require both the
// destination address and p_clipboard to be suitably aligned - confirm at
// callers.
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                          Float* __restrict__ p_dst) const
{
    static_assert(is_same<Float, float>::value, "wrong! only support float!\n");

    using Float2 = float2;
    using Float4 = float4;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr unsigned L0 = CopyLengths{}.Get(I0);
    constexpr unsigned L1 = CopyLengths{}.Get(I1);

    // threads needed to cover one row of L1 elements, DataPerRead elements each
    constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
    // number of rows the block covers in one pass
    constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;

    constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;

    // threads beyond the thread_per_d0 x thread_per_d1 layout take no part
    if(BlockSize > num_active_thread)
    {
        if(get_thread_local_1d_id() >= num_active_thread)
        {
            return;
        }
    }

    // number of full passes over dimension 0
    constexpr unsigned nloop_d0 = L0 / thread_per_d0;

    // destination distance between the rows a thread handles in consecutive passes
    constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

    for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
    {
        if(DataPerRead == 1)
        {
            p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] = p_clipboard[iloop];
        }
        else if(DataPerRead == 2)
        {
            *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                *(reinterpret_cast<const Float2*>(p_clipboard + iloop * 2));
        }
        else if(DataPerRead == 4)
        {
            *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                *(reinterpret_cast<const Float4*>(p_clipboard + iloop * 4));
        }
        else
        {
            assert(false);
        }
    }

    // rows left over when L0 is not a multiple of thread_per_d0
    constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

    if(has_tail_d0)
    {
        constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;

        // only the threads mapped onto the remaining tail_d0 rows copy again
        if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
        {
            if(DataPerRead == 1)
            {
                p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] = p_clipboard[nloop_d0];
            }
            else if(DataPerRead == 2)
            {
                *(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
                                            nloop_d0 * dst_loop_stride)) =
                    *(reinterpret_cast<const Float2*>(p_clipboard + nloop_d0 * 2));
            }
            else if(DataPerRead == 4)
            {
                *(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
                                            nloop_d0 * dst_loop_stride)) =
                    *(reinterpret_cast<const Float4*>(p_clipboard + nloop_d0 * 4));
            }
            else
            {
                assert(false);
            }
        }
    }
}
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -262,8 +262,26 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
|
||||
__syncthreads();
|
||||
|
||||
// load next data
|
||||
#if 0
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
|
||||
#elif 0
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
|
||||
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
#elif 1
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
#endif
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
@@ -283,6 +301,13 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
|
||||
#elif 1
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
|
||||
#endif
|
||||
}
|
||||
|
||||
// last computation
|
||||
|
||||
Reference in New Issue
Block a user