mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
another version of blockwise 2d tensor copy
This commit is contained in:
@@ -376,7 +376,7 @@ int main()
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 0
|
||||
#elif 1
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 36;
|
||||
@@ -427,7 +427,7 @@ int main()
|
||||
#endif
|
||||
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host);
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
#elif 0
|
||||
|
||||
@@ -103,19 +103,6 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
|
||||
@@ -75,10 +75,23 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
constexpr unsigned KPerThread = 1;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned ThreadPerClusterRow = 1;
|
||||
constexpr unsigned ThreadPerClusterColumn = 4;
|
||||
constexpr unsigned GemmThreadPerClusterRow = 1;
|
||||
constexpr unsigned GemmThreadPerClusterColumn = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 32;
|
||||
#elif 0
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
|
||||
constexpr unsigned BPerThread = 8;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerClusterRow = 4;
|
||||
constexpr unsigned GemmThreadPerClusterColumn = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
@@ -88,8 +101,11 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned ThreadPerClusterRow = 4;
|
||||
constexpr unsigned ThreadPerClusterColumn = 4;
|
||||
constexpr unsigned GemmThreadPerClusterRow = 4;
|
||||
constexpr unsigned GemmThreadPerClusterColumn = 4;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 2;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 64;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
@@ -132,8 +148,10 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
ThreadPerClusterRow,
|
||||
ThreadPerClusterColumn>
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
|
||||
@@ -162,11 +162,188 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
}
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void blockwise_2d_tensor_copy(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
struct blockwise_2d_tensor_copy_1
|
||||
{
|
||||
constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
|
||||
|
||||
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
unsigned ThreadPerDim0,
|
||||
unsigned ThreadPerDim1>
|
||||
struct blockwise_2d_tensor_copy_2
|
||||
{
|
||||
unsigned mThreadId0;
|
||||
unsigned mThreadId1;
|
||||
|
||||
__device__ blockwise_2d_tensor_copy_2()
|
||||
{
|
||||
mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
|
||||
mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
|
||||
}
|
||||
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
|
||||
return;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
|
||||
|
||||
constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
|
||||
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
|
||||
|
||||
constexpr unsigned Dim1V4Loop = L1 / (ThreadPerDim1 * 4);
|
||||
constexpr unsigned Dim1V2Loop =
|
||||
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2);
|
||||
constexpr unsigned Dim1V1Loop =
|
||||
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
|
||||
ThreadPerDim1;
|
||||
constexpr bool d1_has_tail =
|
||||
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
|
||||
|
||||
for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
|
||||
{
|
||||
unsigned did0 = d0loop * ThreadPerDim0 + mThreadId0;
|
||||
|
||||
// v4
|
||||
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
|
||||
{
|
||||
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
|
||||
for(unsigned i = 0; i < 4; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
|
||||
// v2
|
||||
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
|
||||
{
|
||||
unsigned did1 =
|
||||
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
|
||||
|
||||
for(unsigned i = 0; i < 2; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
|
||||
// v1
|
||||
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
|
||||
{
|
||||
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
d1v1loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
|
||||
// dim-1 tail
|
||||
if(d1_has_tail)
|
||||
{
|
||||
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
if(did1 < L1)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dim-0 tail
|
||||
if(d0_has_tail)
|
||||
{
|
||||
unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
|
||||
|
||||
if(did0 < L0)
|
||||
{
|
||||
|
||||
// v4
|
||||
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
|
||||
{
|
||||
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
|
||||
for(unsigned i = 0; i < 4; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
|
||||
// v2
|
||||
for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
|
||||
{
|
||||
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
|
||||
2 * mThreadId1;
|
||||
|
||||
for(unsigned i = 0; i < 2; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
|
||||
// v1
|
||||
for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
|
||||
{
|
||||
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
|
||||
Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
|
||||
mThreadId1;
|
||||
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
|
||||
// tail
|
||||
if(d1_has_tail)
|
||||
{
|
||||
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
|
||||
Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
|
||||
mThreadId1;
|
||||
|
||||
if(did1 < L1)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#define WARPSIZE 32;
|
||||
|
||||
template <class T1, class T2>
|
||||
struct is_same
|
||||
{
|
||||
|
||||
@@ -153,6 +153,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1);
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
// convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
@@ -165,7 +166,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths(),
|
||||
reorder_chwn_from_nchw);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
// format is [S,R,C,K], no conversion needed
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
@@ -175,6 +178,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
wei_srck_block_desc,
|
||||
p_wei_block,
|
||||
wei_srck_block_desc.GetLengths());
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -20,8 +20,10 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned ThreadPerClusterRow,
|
||||
unsigned ThreadPerClusterColumn>
|
||||
unsigned GemmThreadPerClusterRow,
|
||||
unsigned GemmThreadPerClusterColumn,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
@@ -104,6 +106,26 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// blockwise 2d copy
|
||||
const auto blockwise_2d_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
// blockwise 2d copy
|
||||
const auto blockwise_2d_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
@@ -130,8 +152,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
ThreadPerClusterRow,
|
||||
ThreadPerClusterColumn,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
@@ -152,12 +174,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
{
|
||||
// input: global mem to LDS,
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
blockwise_2d_tensor_copy<BlockSize>(
|
||||
in_cb_global_desc,
|
||||
blockwise_2d_copy.run(
|
||||
p_in_global + in_cb_global_desc.Get1dIndex(c_block_data_begin, b_block_data_begin),
|
||||
in_cb_block_desc,
|
||||
p_in_block,
|
||||
in_cb_block_desc.GetLengths());
|
||||
p_in_block);
|
||||
|
||||
// weight: global mem to LDS,
|
||||
// format is [S,R,CPerBlock,KPerBlock]
|
||||
@@ -245,22 +264,6 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
k,
|
||||
b,
|
||||
k_data,
|
||||
n_data,
|
||||
h_data,
|
||||
w_data,
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user