mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Dynamic tensor descriptor (#24)
* Support dynamic tensor descriptor
* Use buffer load OOB feature for the padding case
* Add Navi support
* Add int8x4 inference kernel

Co-authored-by: Chao Liu <chao@ixt-rack-81.local.lan>
Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||
}
|
||||
|
||||
template <class InDesc, class WeiDesc, class OutDesc>
|
||||
constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
|
||||
constexpr std::size_t
|
||||
calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t N = out_desc.GetLength(I0);
|
||||
constexpr index_t K = out_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_desc.GetLength(I3);
|
||||
const index_t N = out_desc.GetLength(I0);
|
||||
const index_t K = out_desc.GetLength(I1);
|
||||
const index_t Ho = out_desc.GetLength(I2);
|
||||
const index_t Wo = out_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t C = wei_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_desc.GetLength(I3);
|
||||
const index_t C = wei_desc.GetLength(I1);
|
||||
const index_t Y = wei_desc.GetLength(I2);
|
||||
const index_t X = wei_desc.GetLength(I3);
|
||||
|
||||
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
for(index_t i = 0; i < 1; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
|
||||
@@ -57,10 +57,41 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// GemmABlockCopySrcDataPerRead_GemmM = 4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
@@ -74,11 +105,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
@@ -104,11 +135,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
@@ -222,7 +222,7 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc i
|
||||
|
||||
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
|
||||
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
|
||||
constexpr index_t gemm_k2 = gemm_sizes.At(4);
|
||||
constexpr index_t gemm_k2 = gemm_sizes[Number<4>{}];
|
||||
constexpr bool is_gemm_not_empty = gemm_k2 > 0;
|
||||
|
||||
// only compile and run if GEMM is not empty
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
#include "gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
@@ -13,18 +13,20 @@ template <typename T,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ck::index_t nrepeat)
|
||||
void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
std::cout << "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw" << std::endl;
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
|
||||
@@ -133,7 +135,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -770,45 +772,46 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmNRepeat,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>;
|
||||
using gridwise_conv =
|
||||
GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmNRepeat,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
@@ -12,18 +12,20 @@ template <class T,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
std::cout << "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" << std::endl;
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
|
||||
@@ -55,6 +57,109 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
|
||||
#elif 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize = 64, 16x128x4
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 64x256x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -62,14 +167,14 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
@@ -86,6 +191,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x2
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x4
|
||||
@@ -99,10 +237,10 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
@@ -122,6 +260,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// b threadwise copy 4x1
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -152,6 +291,40 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// b threadwise copy 2x2
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
@@ -255,7 +428,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
|
||||
@@ -289,6 +462,41 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 128x64x4
|
||||
@@ -826,7 +1034,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 64x64x3
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
@@ -968,7 +1176,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
|
||||
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
TDevice,
|
||||
@@ -0,0 +1,207 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
std::cout << "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" << std::endl;
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto N = OutDesc::GetLengths()[I0];
|
||||
constexpr auto K = OutDesc::GetLengths()[I1];
|
||||
constexpr auto C = WeiDesc::GetLengths()[I1];
|
||||
|
||||
constexpr auto Hi = InDesc::GetLengths()[I2];
|
||||
constexpr auto Wi = InDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Ho = OutDesc::GetLengths()[I2];
|
||||
constexpr auto Wo = OutDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Y = WeiDesc::GetLengths()[I2];
|
||||
constexpr auto X = WeiDesc::GetLengths()[I3];
|
||||
|
||||
// compile-time variables
|
||||
constexpr auto in_n_hi_wi_c_desc =
|
||||
make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
|
||||
constexpr auto wei_k_y_x_c_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
|
||||
|
||||
Tensor<float> in_nhwc(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
|
||||
Tensor<float> wei_kyxc(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
|
||||
Tensor<float> out_nhwk(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
|
||||
|
||||
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
|
||||
};
|
||||
|
||||
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
|
||||
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
|
||||
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
|
||||
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
|
||||
|
||||
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
|
||||
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
|
||||
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
|
||||
|
||||
#if 1
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmM1 = 2;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM = K;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
TDevice,
|
||||
TDevice,
|
||||
decltype(in_n_hi_wi_c_desc),
|
||||
decltype(wei_k_y_x_c_desc),
|
||||
decltype(out_n_ho_wo_k_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmK,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmK,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmM1>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
launch_kernel(run_gridwise_operation<gridwise_conv,
|
||||
const TDevice* const __restrict__,
|
||||
const TDevice* const __restrict__,
|
||||
TDevice* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
|
||||
|
||||
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -0,0 +1,508 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <class TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
class TAcc,
|
||||
class TOut,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
InDesc,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
WeiDesc,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
OutDesc,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
|
||||
<< std::endl;
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
#if 0
|
||||
// run-time variables
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
|
||||
|
||||
const auto conv_strides = to_multi_index(ConvStrides{});
|
||||
const auto conv_dilations = to_multi_index(ConvDilations{});
|
||||
const auto in_left_pads = to_multi_index(InLeftPads{});
|
||||
const auto in_right_pads = to_multi_index(InRightPads{});
|
||||
#else
|
||||
// compile-time variables
|
||||
const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(InDesc::GetLengths()));
|
||||
const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(WeiDesc::GetLengths()));
|
||||
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(OutDesc::GetLengths()));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize 64, 16x128x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize 64, 16x256x2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize 64, 16x256x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize = 64, 16x128x4
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x8
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x2
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// b thread copy 4x1
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// b thread copy 2x2
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
#if 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
|
||||
#elif 0
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
|
||||
#elif 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
|
||||
|
||||
conv_driver.Run(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
|
||||
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,427 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
template <class TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
class TAcc,
|
||||
class TOut,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
InDesc,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
WeiDesc,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
OutDesc,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
|
||||
<< std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto N = OutDesc::GetLengths()[I0];
|
||||
constexpr auto K = OutDesc::GetLengths()[I1];
|
||||
constexpr auto C = WeiDesc::GetLengths()[I1];
|
||||
|
||||
constexpr auto Hi = InDesc::GetLengths()[I2];
|
||||
constexpr auto Wi = InDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Ho = OutDesc::GetLengths()[I2];
|
||||
constexpr auto Wo = OutDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Y = WeiDesc::GetLengths()[I2];
|
||||
constexpr auto X = WeiDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
#if 0
|
||||
// run-time variables
|
||||
constexpr auto in_n_hi_wi_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
|
||||
constexpr auto wei_k_y_x_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0));
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K));
|
||||
|
||||
const auto conv_strides = to_multi_index(ConvStrides{});
|
||||
const auto conv_dilations = to_multi_index(ConvDilations{});
|
||||
const auto in_left_pads = to_multi_index(InLeftPads{});
|
||||
const auto in_right_pads = to_multi_index(InRightPads{});
|
||||
#else
|
||||
// compile-time variables
|
||||
constexpr auto in_n_hi_wi_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0));
|
||||
constexpr auto wei_k_y_x_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0));
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
Tensor<TInWei> in_n_hi_wi_c(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
|
||||
Tensor<TInWei> wei_k_y_x_c(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
|
||||
Tensor<TOut> out_n_ho_wo_k(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
|
||||
|
||||
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x);
|
||||
};
|
||||
|
||||
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
|
||||
out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
|
||||
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
|
||||
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
#if 1
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize = 64, 16x128x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 16x256x2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 1;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 64, 16x256x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x4
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x8
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
#if 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
|
||||
#elif 0
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
|
||||
#elif 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
|
||||
#endif
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmM1>{};
|
||||
|
||||
conv_driver.Run(wei_k_y_x_c0_desc,
|
||||
in_n_hi_wi_c0_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()));
|
||||
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
// Host-side launcher for the v5r1 dynamic forward-convolution implicit-GEMM driver.
//
// Takes host tensors named for NCHW input / KCYX weight / NKHW output, repacks input and
// weight so that channel index c is split into (c / InWeiVectorSize, c % InWeiVectorSize)
// with the remainder as the innermost dimension, uploads both, runs the driver, and copies
// the device result back into out_n_k_ho_wo.
//
// The descriptor / stride / dilation / pad arguments are unnamed: only their types are used
// (via static GetLengths() or default construction).
// NOTE(review): nrepeat is accepted but never referenced in this body.
template <class TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
class TAcc,
|
||||
class TOut,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
InDesc,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
WeiDesc,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
OutDesc,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
// Announce which implementation is being run.
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
|
||||
<< std::endl;
|
||||
|
||||
// Device buffers are sized from the element space of the caller's (un-repacked) host tensors.
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
// Compile-time dimension indices used to pick entries out of the descriptor lengths.
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
// Problem sizes, read statically from the descriptor types.
constexpr auto N = OutDesc::GetLengths()[I0];
|
||||
constexpr auto K = OutDesc::GetLengths()[I1];
|
||||
constexpr auto C = WeiDesc::GetLengths()[I1];
|
||||
|
||||
constexpr auto Hi = InDesc::GetLengths()[I2];
|
||||
constexpr auto Wi = InDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Ho = OutDesc::GetLengths()[I2];
|
||||
constexpr auto Wo = OutDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Y = WeiDesc::GetLengths()[I2];
|
||||
constexpr auto X = WeiDesc::GetLengths()[I3];
|
||||
|
||||
// Channel split: C0 vector groups of C1 = InWeiVectorSize channels each.
// NOTE(review): presumably C must be a multiple of InWeiVectorSize - confirm against callers.
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
// Disabled alternative: build descriptors and conv parameters as multi-index values.
// Note that the names declared in this branch differ from the ones used further below.
#if 0
|
||||
// run-time variables
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
|
||||
|
||||
const auto conv_strides = to_multi_index(ConvStrides{});
|
||||
const auto conv_dilations = to_multi_index(ConvDilations{});
|
||||
const auto in_left_pads = to_multi_index(InLeftPads{});
|
||||
const auto in_right_pads = to_multi_index(InRightPads{});
|
||||
#else
|
||||
// compile-time variables
|
||||
// Active path: 4-d packed descriptors over (N, C0, Hi, Wi) / (K, C0, Y, X) / (N, K, Ho, Wo);
// the C1 dimension is not a descriptor dimension here.
const auto in_n_c0_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
// Host staging tensors holding the repacked 5-d layouts (..., C1 innermost).
Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
|
||||
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
|
||||
|
||||
// Element-wise repack of the input: (n, c, hi, wi) -> (n, c / V, hi, wi, c % V).
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
|
||||
in_n_c_hi_wi(n, c, hi, wi);
|
||||
};
|
||||
|
||||
// Element-wise repack of the weight: (k, c, y, x) -> (k, c / V, y, x, c % V).
auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
|
||||
wei_k_c_y_x(k, c, y, x);
|
||||
};
|
||||
|
||||
// Apply both repacks over their full index ranges (argument order matches the lambdas).
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
|
||||
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
|
||||
|
||||
// Upload the repacked data (not the caller's original layout) to the device buffers.
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
|
||||
// Single hard-coded tuning configuration passed as template arguments to the driver below.
// cdata = 64, BlockSize = 64, 16x8x32x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 4;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
// NOTE(review): the literal 9 looks tied to a 3x3 filter (Y * X) - confirm before reusing
// this launcher with other filter sizes.
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
|
||||
|
||||
// Instantiate the padded v5r1 driver; its input/weight element type is the
// InWeiVectorSize-wide vector of TInWei.
constexpr auto conv_driver =
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W>{};
|
||||
|
||||
// Run the driver. Weight descriptor/pointer come before the input ones; the input and
// weight device pointers are reinterpreted as pointers to the vector type.
conv_driver.Run(wei_k_c0_y_x_desc,
|
||||
in_n_c0_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
|
||||
|
||||
// Copy the device output back into the caller's host tensor.
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -273,7 +273,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ struct ParallelTensorFunctor
|
||||
return indices;
|
||||
}
|
||||
|
||||
void operator()(std::size_t num_thread) const
|
||||
void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
|
||||
{
|
||||
std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
|
||||
|
||||
|
||||
@@ -4,10 +4,7 @@
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "print_array.hpp"
|
||||
#include "print_sequence.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
@@ -54,10 +51,10 @@ int main(int argc, char* argv[])
|
||||
#elif 0
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -156,13 +153,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<2, 2>;
|
||||
using RightPads = Sequence<2, 2>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
@@ -197,7 +194,7 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<2, 2>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
@@ -211,11 +208,11 @@ int main(int argc, char* argv[])
|
||||
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("RightPads", RightPads{});
|
||||
print_sequence("ConvStrides", ConvStrides{});
|
||||
print_sequence("ConvDilations", ConvDilations{});
|
||||
print_array("LeftPads", LeftPads{});
|
||||
print_array("LeftPads", LeftPads{});
|
||||
print_array("RightPads", RightPads{});
|
||||
print_array("ConvStrides", ConvStrides{});
|
||||
print_array("ConvDilations", ConvDilations{});
|
||||
|
||||
Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
|
||||
Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
|
||||
@@ -248,7 +245,7 @@ int main(int argc, char* argv[])
|
||||
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
conv_bwd_data_driver.cpp
|
||||
@@ -5,27 +5,29 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print_array.hpp"
|
||||
#include "print_sequence.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
// 1x1, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
@@ -35,6 +37,135 @@ int main(int argc, char* argv[])
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 1;
|
||||
constexpr index_t HI = 1024;
|
||||
constexpr index_t WI = 2048;
|
||||
constexpr index_t K = 4;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 36x36, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 37;
|
||||
constexpr index_t WI = 37;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 71x71
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 71;
|
||||
constexpr index_t WI = 71;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
// 1x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1536;
|
||||
@@ -70,7 +201,7 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t C = 96;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -94,7 +225,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 7x1, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -109,7 +240,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x7, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -141,12 +272,11 @@ int main(int argc, char* argv[])
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 147x147
|
||||
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 32;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 147;
|
||||
constexpr index_t WI = 147;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -157,7 +287,6 @@ int main(int argc, char* argv[])
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 149x149
|
||||
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 32;
|
||||
constexpr index_t HI = 149;
|
||||
@@ -201,7 +330,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 288;
|
||||
@@ -244,21 +373,6 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 0>;
|
||||
using RightPads = Sequence<1, 0>;
|
||||
#elif 0
|
||||
// 3x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 448;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 0>;
|
||||
using RightPads = Sequence<1, 0>;
|
||||
#elif 0
|
||||
@@ -278,7 +392,6 @@ int main(int argc, char* argv[])
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 7x1, 73x73
|
||||
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 73;
|
||||
@@ -352,7 +465,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -382,7 +495,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1, 56x56, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
@@ -442,7 +555,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x1, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
@@ -472,7 +585,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
@@ -487,7 +600,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
@@ -512,17 +625,26 @@ int main(int argc, char* argv[])
|
||||
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("RightPads", RightPads{});
|
||||
print_sequence("ConvStrides", ConvStrides{});
|
||||
print_sequence("ConvDilations", ConvDilations{});
|
||||
print_array("LeftPads", to_multi_index(LeftPads{}));
|
||||
print_array("RightPads", to_multi_index(RightPads{}));
|
||||
print_array("ConvStrides", to_multi_index(ConvStrides{}));
|
||||
print_array("ConvDilations", to_multi_index(ConvDilations{}));
|
||||
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
using out_data_t = float;
|
||||
#else
|
||||
using in_data_t = half_float::half;
|
||||
using out_data_t = half_float::half;
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 0
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = int8_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
constexpr index_t in_vector_size = 4;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
|
||||
@@ -532,14 +654,15 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(argc != 3)
|
||||
if(argc != 4)
|
||||
{
|
||||
printf("arg1: do_verification, arg2: nrepeat\n");
|
||||
printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool do_verification = atoi(argv[1]);
|
||||
index_t nrepeat = atoi(argv[2]);
|
||||
bool do_log = atoi(argv[2]);
|
||||
index_t nrepeat = atoi(argv[3]);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -548,9 +671,9 @@ int main(int argc, char* argv[])
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
@@ -565,59 +688,112 @@ int main(int argc, char* argv[])
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 1
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#if 0
|
||||
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>
|
||||
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>
|
||||
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
#if 0
|
||||
if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
|
||||
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
|
||||
{
|
||||
host_winograd_3x3_convolution(
|
||||
in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{});
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
host_direct_convolution(in_nchw,
|
||||
wei_kcyx,
|
||||
out_nkhw_host,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{});
|
||||
}
|
||||
host_direct_convolution(in_nchw,
|
||||
wei_kcyx,
|
||||
out_nkhw_host,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{});
|
||||
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
|
||||
#if 0
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
|
||||
#endif
|
||||
if(do_log)
|
||||
{
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
conv_driver.cpp
|
||||
Reference in New Issue
Block a user