mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Add xdlops v4r4r4 into online compilation (#48)
* init for v4r4 xdlops olc * refactor wrap * init impl of v4r4 nchw xdlops olc * tuning * test perf * fixed v4r4 nhwc * tuned v4r4 nhwc * use gridwise_gemm_xdlops_v2r3 * swap a/b * add pointer support into offline v2r3 * debugging v4r4r4 transform for olc * change timer of olc * refactor v4r4 xdlops nchw olc * remove transform fun in v4r4 xdlops nhwc olc Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -50,6 +50,146 @@ static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r
|
||||
{0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
|
||||
5, 1};
|
||||
|
||||
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
||||
{
|
||||
ck::index_t BlockSize; // usually not tunable
|
||||
|
||||
ck::index_t MPerBlock;
|
||||
ck::index_t NPerBlock;
|
||||
ck::index_t KPerBlock;
|
||||
|
||||
ck::index_t MPerWave;
|
||||
ck::index_t NPerWave;
|
||||
ck::index_t K1;
|
||||
|
||||
ck::index_t MRepeat;
|
||||
ck::index_t NRepeat;
|
||||
|
||||
std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
|
||||
std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
|
||||
std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
|
||||
std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
|
||||
ck::index_t ABlockTransferSrcVectorDim;
|
||||
ck::index_t ABlockTransferSrcScalarPerVector;
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1;
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
|
||||
std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
|
||||
std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
|
||||
std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
|
||||
ck::index_t BBlockTransferSrcVectorDim;
|
||||
ck::index_t BBlockTransferSrcScalarPerVector;
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1;
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
|
||||
ck::index_t CThreadTransferSrcDstVectorDim;
|
||||
ck::index_t CThreadTransferDstScalarPerVector;
|
||||
};
|
||||
|
||||
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
||||
default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
|
||||
256, // BlockSize
|
||||
128, // MPerBlock,
|
||||
128, // NPerBlock,
|
||||
4, // KPerBlock,
|
||||
32, // MPerWave,
|
||||
32, // NPerWave,
|
||||
4, // K1,
|
||||
2, // MRepeat,
|
||||
2, // NRepeat,
|
||||
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector,
|
||||
4, // ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
{0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun
|
||||
{3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
1 // CThreadTransferDstScalarPerVector
|
||||
};
|
||||
|
||||
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
|
||||
{
|
||||
ck::index_t BlockSize; // usually not tunable
|
||||
|
||||
ck::index_t MPerBlock;
|
||||
ck::index_t NPerBlock;
|
||||
ck::index_t KPerBlock;
|
||||
|
||||
ck::index_t MPerWave;
|
||||
ck::index_t NPerWave;
|
||||
ck::index_t K1;
|
||||
|
||||
ck::index_t MRepeat;
|
||||
ck::index_t NRepeat;
|
||||
|
||||
std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
|
||||
std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
|
||||
std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
|
||||
std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
|
||||
ck::index_t ABlockTransferSrcVectorDim;
|
||||
ck::index_t ABlockTransferSrcScalarPerVector;
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1;
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
|
||||
std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
|
||||
std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
|
||||
std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
|
||||
ck::index_t BBlockTransferSrcVectorDim;
|
||||
ck::index_t BBlockTransferSrcScalarPerVector;
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1;
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun;
|
||||
|
||||
std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
|
||||
ck::index_t CThreadTransferSrcDstVectorDim;
|
||||
ck::index_t CThreadTransferDstScalarPerVector;
|
||||
};
|
||||
|
||||
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
|
||||
default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
|
||||
256, // BlockSize
|
||||
128, // MPerBlock,
|
||||
128, // NPerBlock,
|
||||
4, // KPerBlock,
|
||||
32, // MPerWave,
|
||||
32, // NPerWave,
|
||||
4, // K1,
|
||||
2, // MRepeat,
|
||||
2, // NRepeat,
|
||||
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
4, // ABlockTransferSrcScalarPerVector,
|
||||
4, // ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
{1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder,
|
||||
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
4, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun
|
||||
{2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
1 // CThreadTransferDstScalarPerVector
|
||||
};
|
||||
|
||||
struct tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw
|
||||
{
|
||||
ck::index_t BlockSize;
|
||||
|
||||
@@ -273,6 +273,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
|
||||
@@ -245,6 +245,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
#if 0
|
||||
float ave_time = launch_kernel_dynamic_gemm_xdlops_v1
|
||||
#else
|
||||
float ave_time = launch_kernel_dynamic_gemm_xdlops_v2
|
||||
#endif
|
||||
<BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmKPack,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<1, 0, 2>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused
|
||||
// with MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2r2.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
@@ -56,63 +56,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -120,12 +64,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
@@ -139,34 +83,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
@@ -200,10 +116,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
@@ -216,7 +140,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_gemm_xdlops_v2r2<
|
||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
@@ -230,6 +154,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
@@ -248,26 +173,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 0, 1, 2>,
|
||||
3,
|
||||
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_gemm_xdlops_v2r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1>,
|
||||
2,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -1,305 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -223,34 +223,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
@@ -325,6 +297,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
|
||||
@@ -0,0 +1,376 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "olc_driver_common.hpp"
|
||||
#include "conv_tunables.hpp"
|
||||
|
||||
#include "handle.hpp"
|
||||
|
||||
namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw {
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_network_config_string_from_types()
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += static_cast<char>(Driver::get_typeid_from_type<TInWei>()) +
|
||||
static_cast<char>(Driver::get_typeid_from_type<TAcc>()) +
|
||||
static_cast<char>(Driver::get_typeid_from_type<TOut>());
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt)
|
||||
{
|
||||
std::string out("TUN_");
|
||||
|
||||
out += std::to_string(pt->BlockSize) + "_";
|
||||
|
||||
out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
|
||||
std::to_string(pt->KPerBlock) + "_";
|
||||
out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" +
|
||||
std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" +
|
||||
std::to_string(pt->K1) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
|
||||
out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
|
||||
out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_";
|
||||
out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
|
||||
out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
|
||||
out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_";
|
||||
out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
|
||||
|
||||
out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_";
|
||||
|
||||
out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
|
||||
out += std::to_string(pt->CThreadTransferDstScalarPerVector);
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_definition_string_from_types()
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TInWei>()) +
|
||||
" -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
|
||||
" -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt)
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
|
||||
|
||||
out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
|
||||
" -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
|
||||
" -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
|
||||
out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
|
||||
" -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
|
||||
" -DCK_PARAM_K1=" + std::to_string(pt->K1) +
|
||||
" -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
|
||||
" -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
|
||||
out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
|
||||
std::to_string(pt->ABlockTransferSrcScalarPerVector);
|
||||
out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
|
||||
std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
|
||||
out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
|
||||
std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
|
||||
out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
|
||||
std::to_string(pt->BBlockTransferSrcScalarPerVector);
|
||||
out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
|
||||
std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
|
||||
out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
|
||||
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);
|
||||
|
||||
out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
|
||||
std::to_string(pt->CThreadTransferSrcDstVectorDim);
|
||||
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
||||
std::to_string(pt->CThreadTransferDstScalarPerVector);
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc(
|
||||
olCompile::Handle* handle,
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
||||
using size_t = std::size_t;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
const auto n = in_n_c_hi_wi_desc.GetLength(I0);
|
||||
const auto c = in_n_c_hi_wi_desc.GetLength(I1);
|
||||
const auto hi = in_n_c_hi_wi_desc.GetLength(I2);
|
||||
const auto wi = in_n_c_hi_wi_desc.GetLength(I3);
|
||||
const auto k = wei_k_c_y_x_desc.GetLength(I0);
|
||||
const auto y = wei_k_c_y_x_desc.GetLength(I2);
|
||||
const auto x = wei_k_c_y_x_desc.GetLength(I3);
|
||||
const auto ho = out_n_k_ho_wo_desc.GetLength(I2);
|
||||
const auto wo = out_n_k_ho_wo_desc.GetLength(I3);
|
||||
|
||||
const auto M = k;
|
||||
const auto N = n * ho * wo;
|
||||
const auto K = c * y * x;
|
||||
const auto K0 = K / tunable->K1;
|
||||
|
||||
const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
// these are workspace buffers that should be expressed to the user by the corresponding
|
||||
// workspace API
|
||||
DeviceMem workspace_buf(4096);
|
||||
|
||||
void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
|
||||
void* b_k_n0_n1_grid_desc_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
|
||||
void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
|
||||
void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
|
||||
|
||||
const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
|
||||
|
||||
std::string program_name =
|
||||
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp";
|
||||
std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nchw";
|
||||
|
||||
std::string param = " -std=c++17 ";
|
||||
std::string network_config;
|
||||
|
||||
param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " " + " -DCK_USE_AMD_XDLOPS" +
|
||||
get_definition_string_from_tunable(tunable);
|
||||
|
||||
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
|
||||
get_network_config_string_from_tunable(tunable);
|
||||
|
||||
std::vector<float> kernel1_times;
|
||||
std::vector<float> kernel2_times;
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
KernelTimer timer1, timer2;
|
||||
std::string kernel_name;
|
||||
|
||||
kernel_name =
|
||||
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare";
|
||||
auto network_config_1 = network_config + "_1";
|
||||
|
||||
timer1.Start();
|
||||
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
|
||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
|
||||
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
|
||||
conv_strides[I0],
|
||||
conv_strides[I1],
|
||||
conv_dilations[I0],
|
||||
conv_dilations[I1],
|
||||
in_left_pads[I0],
|
||||
in_left_pads[I1],
|
||||
in_right_pads[I0],
|
||||
in_right_pads[I1],
|
||||
a_k_m0_m1_grid_desc_dev_buf,
|
||||
b_k_n0_n1_grid_desc_dev_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
|
||||
timer1.End();
|
||||
|
||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw";
|
||||
auto network_config_2 = network_config + "_2";
|
||||
|
||||
timer2.Start();
|
||||
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
|
||||
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
|
||||
(const void*)(a_k_m0_m1_grid_desc_dev_buf),
|
||||
(const void*)(b_k_n0_n1_grid_desc_dev_buf),
|
||||
(const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf),
|
||||
(const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
|
||||
timer2.End();
|
||||
|
||||
kernel1_times.push_back(timer1.GetElapsedTime());
|
||||
kernel2_times.push_back(timer2.GetElapsedTime());
|
||||
}
|
||||
|
||||
{
|
||||
auto ave_time1 = Driver::get_effective_average(kernel1_times);
|
||||
auto ave_time2 = Driver::get_effective_average(kernel2_times);
|
||||
|
||||
const auto N = in_n_c_hi_wi_lengths[I0];
|
||||
const auto C = in_n_c_hi_wi_lengths[I1];
|
||||
|
||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
||||
|
||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
||||
const auto X = wei_k_c_y_x_lengths[I3];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
|
||||
|
||||
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
||||
};
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,379 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#include "olc_driver_common.hpp"
|
||||
#include "conv_tunables.hpp"
|
||||
|
||||
#include "handle.hpp"
|
||||
|
||||
namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk {
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_network_config_string_from_types()
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += static_cast<char>(Driver::get_typeid_from_type<TInWei>()) +
|
||||
static_cast<char>(Driver::get_typeid_from_type<TAcc>()) +
|
||||
static_cast<char>(Driver::get_typeid_from_type<TOut>());
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
|
||||
{
|
||||
std::string out("TUN_");
|
||||
|
||||
out += std::to_string(pt->BlockSize) + "_";
|
||||
|
||||
out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
|
||||
std::to_string(pt->KPerBlock) + "_";
|
||||
out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" +
|
||||
std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" +
|
||||
std::to_string(pt->K1) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
|
||||
out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
|
||||
out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_";
|
||||
out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
|
||||
out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
|
||||
out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_";
|
||||
out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
|
||||
|
||||
out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_";
|
||||
|
||||
out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
|
||||
out += std::to_string(pt->CThreadTransferDstScalarPerVector);
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_definition_string_from_types()
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TInWei>()) +
|
||||
" -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
|
||||
" -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
|
||||
{
|
||||
std::string out;
|
||||
|
||||
out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
|
||||
|
||||
out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
|
||||
" -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
|
||||
" -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
|
||||
out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
|
||||
" -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
|
||||
" -DCK_PARAM_K1=" + std::to_string(pt->K1) +
|
||||
" -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
|
||||
" -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);
|
||||
|
||||
out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
|
||||
out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
|
||||
std::to_string(pt->ABlockTransferSrcScalarPerVector);
|
||||
out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
|
||||
std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
|
||||
out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
|
||||
std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);
|
||||
|
||||
out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]);
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
|
||||
out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
|
||||
std::to_string(pt->BBlockTransferSrcScalarPerVector);
|
||||
out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
|
||||
std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
|
||||
out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
|
||||
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);
|
||||
|
||||
out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
|
||||
std::to_string(pt->CThreadTransferSrcDstVectorDim);
|
||||
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
||||
std::to_string(pt->CThreadTransferDstScalarPerVector);
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc(
|
||||
olCompile::Handle* handle,
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
|
||||
using size_t = std::size_t;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// The follow codes are only used for computing the grid_size, hasMainKBlockLoop,
|
||||
// hasDoubleTailKBlockLoop
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
const auto n = in_n_hi_wi_c_desc.GetLength(I0);
|
||||
const auto hi = in_n_hi_wi_c_desc.GetLength(I1);
|
||||
const auto wi = in_n_hi_wi_c_desc.GetLength(I2);
|
||||
const auto c = in_n_hi_wi_c_desc.GetLength(I3);
|
||||
|
||||
const auto k = wei_k_y_x_c_desc.GetLength(I0);
|
||||
const auto y = wei_k_y_x_c_desc.GetLength(I1);
|
||||
const auto x = wei_k_y_x_c_desc.GetLength(I2);
|
||||
|
||||
const auto ho = out_n_ho_wo_k_desc.GetLength(I1);
|
||||
const auto wo = out_n_ho_wo_k_desc.GetLength(I2);
|
||||
|
||||
const auto M = k;
|
||||
const auto N = n * ho * wo;
|
||||
const auto K = c * y * x;
|
||||
const auto K0 = K / tunable->K1;
|
||||
|
||||
const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem in_n_hi_wi_c_dev_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_dev_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_dev_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_dev_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_dev_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_dev_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
// these are workspace buffers that should be expressed to the user by the corresponding
|
||||
// workspace API
|
||||
DeviceMem workspace_buf(4096);
|
||||
|
||||
void* a_k0_m_k1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
|
||||
void* b_k0_n_k1_grid_desc_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
|
||||
void* c_m0_m1_m2_n_grid_desc_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
|
||||
void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
|
||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
|
||||
|
||||
const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
|
||||
|
||||
std::string program_name =
|
||||
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp";
|
||||
std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nhwc";
|
||||
|
||||
std::string param = " -std=c++17 ";
|
||||
std::string network_config;
|
||||
|
||||
param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " -DCK_USE_AMD_XDLOPS ";
|
||||
param += get_definition_string_from_tunable(tunable);
|
||||
|
||||
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
|
||||
get_network_config_string_from_tunable(tunable);
|
||||
|
||||
std::vector<float> kernel1_times;
|
||||
std::vector<float> kernel2_times;
|
||||
|
||||
KernelTimer timer1, timer2;
|
||||
std::string kernel_name;
|
||||
|
||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare";
|
||||
auto network_config_1 = network_config + "_1";
|
||||
|
||||
timer1.Start();
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
|
||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I0]),
|
||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I1]),
|
||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I2]),
|
||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I3]),
|
||||
static_cast<index_t>(wei_k_y_x_c_lengths[I0]),
|
||||
static_cast<index_t>(wei_k_y_x_c_lengths[I1]),
|
||||
static_cast<index_t>(wei_k_y_x_c_lengths[I2]),
|
||||
conv_strides[I0],
|
||||
conv_strides[I1],
|
||||
conv_dilations[I0],
|
||||
conv_dilations[I1],
|
||||
in_left_pads[I0],
|
||||
in_left_pads[I1],
|
||||
in_right_pads[I0],
|
||||
in_right_pads[I1],
|
||||
a_k0_m_k1_grid_desc_dev_buf,
|
||||
b_k0_n_k1_grid_desc_dev_buf,
|
||||
c_m0_m1_m2_n_grid_desc_dev_buf,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
|
||||
}
|
||||
timer1.End();
|
||||
|
||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk";
|
||||
auto network_config_2 = network_config + "_2";
|
||||
|
||||
timer2.Start();
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
|
||||
reinterpret_cast<const TInWei*>(in_n_hi_wi_c_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const TInWei*>(wei_k_y_x_c_dev_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<TOut*>(out_n_ho_wo_k_dev_buf.GetDeviceBuffer()),
|
||||
(const void*)(a_k0_m_k1_grid_desc_dev_buf),
|
||||
(const void*)(b_k0_n_k1_grid_desc_dev_buf),
|
||||
(const void*)(c_m0_m1_m2_n_grid_desc_dev_buf),
|
||||
(const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
|
||||
}
|
||||
timer2.End();
|
||||
|
||||
{
|
||||
auto ave_time1 = timer1.GetElapsedTime() / nrepeat;
|
||||
auto ave_time2 = timer2.GetElapsedTime() / nrepeat;
|
||||
|
||||
const auto N = in_n_hi_wi_c_lengths[I0];
|
||||
const auto C = in_n_hi_wi_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time2;
|
||||
|
||||
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
||||
};
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_dev_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
Reference in New Issue
Block a user