mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
update with new copy op
This commit is contained in:
@@ -391,7 +391,7 @@ int main()
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
@@ -587,11 +587,11 @@ int main()
|
||||
device_implicit_gemm_convolution_1_nchw_kcsr
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_srck_nkhw
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_1_chwn_csrk_khwn
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_cnhw_srck_knhw
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
|
||||
#endif
|
||||
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
@@ -608,7 +608,7 @@ int main()
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(S == 3 && R == 3)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
|
||||
@@ -87,7 +87,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 8;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
@@ -101,6 +101,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2; // not used, yet
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3 58x58, NKC = 16,256,128
|
||||
@@ -162,7 +168,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// for 1x1, 28x28
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 128;
|
||||
@@ -176,6 +182,12 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
@@ -211,7 +223,11 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
WoPerThread,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>
|
||||
<<<grid_dim, block_dim>>>(in_chwn_desc,
|
||||
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
|
||||
wei_csrk_desc,
|
||||
|
||||
@@ -108,6 +108,9 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3 58x58, NKC = 16,256,128
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "ConstantMatrixDescriptor.cuh"
|
||||
#include "blockwise_4d_tensor_op.cuh"
|
||||
#include "blockwise_2d_tensor_op.cuh"
|
||||
#include "threadwise_4d_tensor_op.cuh"
|
||||
#include "blockwise_gemm.cuh"
|
||||
|
||||
@@ -21,7 +22,11 @@ template <unsigned GridSize,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread>
|
||||
unsigned WoPerThread,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
unsigned WeiBlockCopyThreadPerDim1,
|
||||
unsigned InBlockCopyDataPerRead,
|
||||
unsigned WeiBlockCopyDataPerRead>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
@@ -80,12 +85,19 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
const unsigned hi_block_data_begin = ho_block_data_begin;
|
||||
const unsigned wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// flattend (2d) tensor view of gridwise weight
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
// be careful of alignment
|
||||
constexpr auto in_chwn_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
|
||||
|
||||
constexpr auto wei_csrk_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, S, R, KPerBlock>{});
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
@@ -112,13 +124,31 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths())>{};
|
||||
|
||||
// weight: format is [S,R,C,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*S*R,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_csrk_global_desc),
|
||||
decltype(wei_csrk_block_desc),
|
||||
decltype(wei_csrk_block_desc.GetLengths())>{};
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
@@ -155,12 +185,17 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
constexpr unsigned wei_block_size = wei_csrk_block_desc.GetElementSpace();
|
||||
// LDS: be careful of alignment
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
constexpr unsigned wei_block_size =
|
||||
wei_csrk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_size];
|
||||
__shared__ Float p_wei_block[wei_block_size];
|
||||
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
|
||||
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
|
||||
|
||||
Reference in New Issue
Block a user