mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
v5r1 fusion kernels for inference (#49)
* init
* refactor for 1x1
* rename e0_e1
* add e1 with bugs
* debug
* fixed
* fixed e1
* add timer
* imprve threadwise gemm with dot2
* add e2
* tuning
* seperate c2
* add nhwc
* restore nchwc
* clean
* opt
* fixed; tuning
* add BGlobalMoveSliceWindowStepHacks{}
* tuning
* repeat running
* adjust
* merge v5r1 nchwc
* add adaptors
* split k0 k1 in c_thread_grid
* split h and w
* remove v5r1 nhwc
* clean for pr
* remove host_conv_add
* clean code
* clean
* add dynamic support
* static mode
* test static
* add conv+add fusion
* fixed validation
* naming fix
* use activ_enum
* make static
* refactor conv_add for InMem::add
* add bias
* add conv_out
* add configurable makeddesc
* add maxpool fusion
* add maxpool host for validation
* enable static desc
* conv-only use v5r1_add
* test
* test
* for binary dumps
* fixed incorrect results due to typo
* clean
* debugging maxpool
* workaround with offset trick
* clean code
* modularize ops of fusion
* add gridwise_gemm_v3
* create seperate fusion fun
* enable dynamic mode of conv and conv+resize_add
* add dynamic mode of maxpool
* add pass by point
* add activ_type as arguments
* merge develop
* clean
* reset config to old default
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -0,0 +1,220 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
ck::ActivTypeEnum_t activ_type,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename AddLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
|
||||
const InLengths& in_n_c0_hi_wi_c1_lengths,
|
||||
const WeiLengths& wei_k_c0_y_x_c1_lengths,
|
||||
const AddLengths& add_n_k0_hox2_wox2_k1_lengths,
|
||||
const OutLengths& out_n_k0_ho_wo_k1_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c0_hi_wi_c1,
|
||||
const Tensor<TInWei>& wei_k_c0_y_x_c1,
|
||||
const Tensor<TOut>& bias_k0_k1,
|
||||
const Tensor<TOut>& add_n_k0_hox2_wox2_k1,
|
||||
Tensor<TOut>& add_n_k0_hox2_wox2_k1_out,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = out_n_k0_ho_wo_k1_lengths[I0];
|
||||
const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
|
||||
const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
|
||||
const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
|
||||
const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];
|
||||
|
||||
const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
|
||||
const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
|
||||
const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
|
||||
const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_lengths[I0];
|
||||
const auto Y = wei_k_c0_y_x_c1_lengths[I2];
|
||||
const auto X = wei_k_c0_y_x_c1_lengths[I3];
|
||||
|
||||
const auto Hox2 = add_n_k0_hox2_wox2_k1_lengths[I2];
|
||||
const auto Wox2 = add_n_k0_hox2_wox2_k1_lengths[I3];
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
|
||||
DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) *
|
||||
add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
|
||||
add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
|
||||
|
||||
constexpr index_t InWeiVectorSize = 8;
|
||||
|
||||
if(C1 % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 64;
|
||||
|
||||
constexpr index_t E1 = C0 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t E1PerBlock = C0;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
|
||||
#elif 1
|
||||
constexpr auto BlockSize = 64;
|
||||
|
||||
constexpr auto KPerBlock = 8;
|
||||
constexpr auto HoPerBlock = 8;
|
||||
constexpr auto WoPerBlock = 32;
|
||||
|
||||
constexpr auto E1 = 2 * 9;
|
||||
constexpr auto E2 = 1;
|
||||
constexpr auto K2 = 2;
|
||||
constexpr auto E1PerBlock = 2;
|
||||
|
||||
constexpr auto KPerThread = KPerBlock;
|
||||
constexpr auto HoPerThread = 2;
|
||||
constexpr auto WoPerThread = 2;
|
||||
constexpr auto EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
|
||||
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
|
||||
|
||||
constexpr auto ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr auto ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr auto CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
|
||||
#endif
|
||||
|
||||
const auto in_n_c0_hi_wi_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
|
||||
const auto wei_k_c0_y_x_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
|
||||
const auto add_n_k0_hox2_wox2_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
constexpr auto conv_driver =
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
activ_type>{};
|
||||
|
||||
std::cerr << "conv_bias_activ_resize_add_input_"
|
||||
<< "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
|
||||
<< "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_addout_n" << N << "k" << K0
|
||||
<< "h" << Ho * 2 << "w" << Wo * 2 << "k" << K1 << std::endl;
|
||||
|
||||
for(int i = 0; i < 5; i++)
|
||||
{
|
||||
|
||||
const auto ave_time =
|
||||
conv_driver.Run(wei_k_c0_y_x_c1_desc,
|
||||
in_n_c0_hi_wi_c1_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
add_n_k0_hox2_wox2_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());
|
||||
|
||||
conv_driver.Run(wei_k_c0_y_x_c1_desc,
|
||||
in_n_c0_hi_wi_c1_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
add_n_k0_hox2_wox2_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
|
||||
0);
|
||||
|
||||
add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
ck::ActivTypeEnum_t activ_type,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
|
||||
const InLengths& in_n_c0_hi_wi_c1_lengths,
|
||||
const WeiLengths& wei_k_c0_y_x_c1_lengths,
|
||||
const OutLengths& out_n_k0_ho_wo_k1_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c0_hi_wi_c1,
|
||||
const Tensor<TInWei>& wei_k_c0_y_x_c1,
|
||||
const Tensor<TOut>& bias_k0_k1,
|
||||
Tensor<TOut>& out_n_k0_ho_wo_k1,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = out_n_k0_ho_wo_k1_lengths[I0];
|
||||
const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
|
||||
const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
|
||||
const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
|
||||
const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];
|
||||
|
||||
const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
|
||||
const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
|
||||
const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
|
||||
const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_lengths[I0];
|
||||
const auto Y = wei_k_c0_y_x_c1_lengths[I2];
|
||||
const auto X = wei_k_c0_y_x_c1_lengths[I3];
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
|
||||
|
||||
constexpr index_t InWeiVectorSize = 8;
|
||||
|
||||
if(C1 % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 64;
|
||||
|
||||
constexpr index_t E1 = C0 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t E1PerBlock = C0;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
|
||||
#elif 1
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t E1 = 2 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t E1PerBlock = 2;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
|
||||
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
|
||||
#endif
|
||||
|
||||
if(KPerThread % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
const auto in_n_c0_hi_wi_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
|
||||
const auto wei_k_c0_y_x_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
constexpr auto conv_driver =
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
activ_type>{};
|
||||
|
||||
std::cerr << "conv_bias_activ_input_"
|
||||
<< "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
|
||||
<< "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0
|
||||
<< "h" << Ho << "w" << Wo << "k" << K1 << std::endl;
|
||||
|
||||
for(int i = 0; i < 5; i++)
|
||||
{
|
||||
|
||||
const auto ave_time =
|
||||
conv_driver.Run(wei_k_c0_y_x_c1_desc,
|
||||
in_n_c0_hi_wi_c1_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t /* nrepeat */)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = out_n_k_ho_wo_lengths[I0];
|
||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
||||
const auto C = wei_k_c_y_x_lengths[I1];
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_lengths[I2];
|
||||
const auto Wi = in_n_c_hi_wi_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
||||
|
||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
||||
const auto X = wei_k_c_y_x_lengths[I3];
|
||||
|
||||
const auto C0 = C / Number<InWeiVectorSize>{};
|
||||
const auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
const auto K0 = K / Number<InWeiVectorSize>{};
|
||||
const auto K1 = Number<InWeiVectorSize>{};
|
||||
|
||||
Tensor<TInWei> in_n_c0_hi_wi_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
|
||||
Tensor<TInWei> wei_k_c0_y_x_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{K, C0, Y, X, C1}));
|
||||
Tensor<TOut> out_n_k0_ho_wo_k1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K0, Ho, Wo, K1}));
|
||||
|
||||
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
|
||||
in_n_c_hi_wi(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
|
||||
wei_k_c_y_x(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
|
||||
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
|
||||
const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 64, 16x8x32x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = 16;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#else
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
#if 0
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
|
||||
#endif
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W>{};
|
||||
|
||||
conv_driver.Run(wei_k_c0_y_x_desc,
|
||||
in_n_c0_hi_wi_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));
|
||||
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
|
||||
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_n_k_ho_wo(n, k, ho, wo) =
|
||||
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
ck::ActivTypeEnum_t activ_type,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename MaxLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
|
||||
const InLengths& in_n_c0_hi_wi_c1_lengths,
|
||||
const WeiLengths& wei_k_c0_y_x_c1_lengths,
|
||||
const MaxLengths& max_n_k0_hx_wx_k1_lengths,
|
||||
const OutLengths& out_n_k0_ho_wo_k1_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c0_hi_wi_c1,
|
||||
const Tensor<TInWei>& wei_k_c0_y_x_c1,
|
||||
const Tensor<TOut>& bias_k0_k1,
|
||||
Tensor<TOut>& out_n_k0_ho_wo_k1,
|
||||
Tensor<TOut>& max_n_k0_hx_wx_k1,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = out_n_k0_ho_wo_k1_lengths[I0];
|
||||
const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
|
||||
const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
|
||||
const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
|
||||
const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];
|
||||
|
||||
const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
|
||||
const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
|
||||
const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
|
||||
const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_lengths[I0];
|
||||
const auto Y = wei_k_c0_y_x_c1_lengths[I2];
|
||||
const auto X = wei_k_c0_y_x_c1_lengths[I3];
|
||||
|
||||
const auto Hx = max_n_k0_hx_wx_k1_lengths[I2];
|
||||
const auto Wx = max_n_k0_hx_wx_k1_lengths[I3];
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
DeviceMem max_n_k0_hx_wx_k1_device_buf(sizeof(TOut) *
|
||||
max_n_k0_hx_wx_k1.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
|
||||
max_n_k0_hx_wx_k1_device_buf.ToDevice(max_n_k0_hx_wx_k1.mData.data());
|
||||
|
||||
constexpr index_t InWeiVectorSize = 8;
|
||||
|
||||
if(C1 % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 64;
|
||||
|
||||
constexpr index_t E1 = C0 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t E1PerBlock = C0;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
|
||||
#elif 1
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t E1 = 2 * 9;
|
||||
constexpr index_t E2 = 1;
|
||||
constexpr index_t K2 = 2;
|
||||
constexpr index_t E1PerBlock = 2;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = 1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
|
||||
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
|
||||
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_K = InWeiVectorSize;
|
||||
#endif
|
||||
|
||||
if(KPerThread % InWeiVectorSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
|
||||
}
|
||||
|
||||
const auto in_n_c0_hi_wi_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
|
||||
const auto wei_k_c0_y_x_c1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
|
||||
const auto max_n_k0_hx_wx_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
constexpr auto conv_driver =
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
activ_type>{};
|
||||
|
||||
std::cerr << "conv_bias_activ_maxpool_input_"
|
||||
<< "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1 << "_filter_k" << K
|
||||
<< "c" << C0 << "y" << Y << "x" << X << "c" << C1 << "_convout_n" << N << "k" << K0
|
||||
<< "h" << Ho << "w" << Wo << "k" << K1 << "_maxpoolout_n" << N << "k" << K0 << "h"
|
||||
<< Ho / 2 << "w" << Wo / 2 << "k" << K1 << std::endl;
|
||||
|
||||
for(int i = 0; i < 5; i++)
|
||||
{
|
||||
|
||||
const auto ave_time =
|
||||
conv_driver.Run(wei_k_c0_y_x_c1_desc,
|
||||
in_n_c0_hi_wi_c1_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
max_n_k0_hx_wx_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(max_n_k0_hx_wx_k1_device_buf.GetDeviceBuffer()),
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
max_n_k0_hx_wx_k1_device_buf.FromDevice(max_n_k0_hx_wx_k1.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,565 @@
|
||||
#ifndef DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
#define DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t E1_,
|
||||
ck::index_t E2_,
|
||||
ck::index_t K2_,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t E1PerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E2,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_E2,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_E2,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_K,
|
||||
ck::ActivTypeEnum_t activ_type>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Add,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ck::TensorDescriptor<Add...>& add_n_k0_hox2_wox2_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatC* __restrict__ p_bias_grid,
|
||||
FloatC* __restrict__ p_d_grid,
|
||||
const int nrepeat) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
|
||||
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
|
||||
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
|
||||
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto Hox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I2);
|
||||
const auto Wox2 = add_n_k0_hox2_wox2_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
|
||||
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto OutRightPadHx = Number<OutRightPadH * 2>{};
|
||||
const auto OutRightPadWx = Number<OutRightPadW * 2>{};
|
||||
#else
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto OutRightPadHx = OutRightPadH * 2;
|
||||
const auto OutRightPadWx = OutRightPadW * 2;
|
||||
#endif
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
const auto E = C0 * Y * X;
|
||||
|
||||
constexpr auto E1 = Number<E1_>{};
|
||||
constexpr auto E2 = Number<E2_>{};
|
||||
constexpr auto K2 = Number<K2_>{};
|
||||
|
||||
const auto E0 = E / E1;
|
||||
|
||||
// weight tensor
|
||||
const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_pass_through_transform(C0 * Y * X),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
|
||||
|
||||
const auto a_e0_e1_k_e2_grid_desc =
|
||||
transform_tensor_descriptor(a_e_k_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
|
||||
in_n_c0_hip_wip_e2_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c0_y_ho_x_wo_e2_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(
|
||||
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_e_n_ho_wo_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
// output tensor
|
||||
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, OutRightPadH),
|
||||
make_pad_transform(Wo, I0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// add tensor
|
||||
const auto d_k_n_hopx2_wopx2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Hox2, I0, OutRightPadHx),
|
||||
make_pad_transform(Wox2, I0, OutRightPadWx)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E1 % E1PerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
|
||||
constexpr auto a_e0_e1_k_e2_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
|
||||
make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
|
||||
);
|
||||
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
|
||||
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
// clang-format on
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_e0_e1_k_e2_grid_desc),
|
||||
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
decltype(d_k_n_hopx2_wopx2_grid_desc),
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
Sequence<2, 3, 0, 1, 4>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
|
||||
9,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2
|
||||
1,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
decltype(a_e0_e1_k_e2_global_step_hacks),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks),
|
||||
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto a_e0_e1_k0_k1_e2_grid_desc =
|
||||
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
|
||||
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
|
||||
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
|
||||
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
|
||||
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
|
||||
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
|
||||
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(
|
||||
d_k_n_hopx2_wopx2_grid_desc);
|
||||
|
||||
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
|
||||
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
|
||||
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
|
||||
|
||||
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_e0_block_loop = E0 > 1;
|
||||
|
||||
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
|
||||
|
||||
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
|
||||
decltype(c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_resize_add<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_resize_add<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
|
||||
DeviceMem b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf(
|
||||
sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2));
|
||||
DeviceMem c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf(
|
||||
sizeof(CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2));
|
||||
DeviceMem d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf(
|
||||
sizeof(DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2));
|
||||
DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf(
|
||||
sizeof(CBlockIdToBlockClusterAdaptor_K_N_H_W));
|
||||
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.ToDevice(
|
||||
&b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.ToDevice(
|
||||
&c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.ToDevice(
|
||||
&d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice(
|
||||
&c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
|
||||
const auto kernel = kernel_gemm_dlops_v3_resize_add<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_d_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_resize_add<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_d_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
{
|
||||
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(c_blockid_to_k_n_h_w_block_cluster_adaptor.IsKnownAtCompileTime(), "");
|
||||
|
||||
const auto kernel = kernel_gemm_dlops_v3_resize_add<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
has_main_e0_block_loop,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_d_grid);
|
||||
}
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -0,0 +1,500 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t E1_,
|
||||
ck::index_t E2_,
|
||||
ck::index_t K2_,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t E1PerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E2,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_E2,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_E2,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_K,
|
||||
ck::ActivTypeEnum_t activ_type>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatC* __restrict__ p_bias_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const int nrepeat) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
|
||||
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
|
||||
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
|
||||
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
|
||||
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
|
||||
#else
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
#endif
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
const auto E = C0 * Y * X;
|
||||
|
||||
constexpr auto E1 = Number<E1_>{};
|
||||
constexpr auto E2 = Number<E2_>{};
|
||||
constexpr auto K2 = Number<K2_>{};
|
||||
|
||||
const auto E0 = E / E1;
|
||||
|
||||
// weight tensor
|
||||
const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_pass_through_transform(C0 * Y * X),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
|
||||
|
||||
const auto a_e0_e1_k_e2_grid_desc =
|
||||
transform_tensor_descriptor(a_e_k_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
|
||||
in_n_c0_hip_wip_e2_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c0_y_ho_x_wo_e2_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(
|
||||
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_e_n_ho_wo_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
// output tensor
|
||||
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, OutRightPadH),
|
||||
make_pad_transform(Wo, I0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E1 % E1PerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
|
||||
constexpr auto a_e0_e1_k_e2_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
|
||||
make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
|
||||
);
|
||||
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
|
||||
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
// clang-format on
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_e0_e1_k_e2_grid_desc),
|
||||
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
Sequence<2, 3, 0, 1, 4>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
|
||||
9,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, H2, W0, W1, W2
|
||||
1,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
decltype(a_e0_e1_k_e2_global_step_hacks),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto a_e0_e1_k0_k1_e2_grid_desc =
|
||||
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
|
||||
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
|
||||
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
|
||||
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
|
||||
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
|
||||
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
|
||||
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_e0_block_loop = E0 > 1;
|
||||
|
||||
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
|
||||
|
||||
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
|
||||
decltype(c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
|
||||
DeviceMem b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf(
|
||||
sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2));
|
||||
DeviceMem c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf(
|
||||
sizeof(CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2));
|
||||
DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf(
|
||||
sizeof(CBlockIdToBlockClusterAdaptor_K_N_H_W));
|
||||
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.ToDevice(
|
||||
&b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.ToDevice(
|
||||
&c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice(
|
||||
&c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
{
|
||||
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(c_blockid_to_k_n_h_w_block_cluster_adaptor.IsKnownAtCompileTime(), "");
|
||||
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
has_main_e0_block_loop,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid);
|
||||
}
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -1,349 +0,0 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t EPerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_W,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_W>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_wei_global,
|
||||
const FloatAB* __restrict__ p_in_global,
|
||||
FloatC* __restrict__ p_out_global) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(Wo)),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(Wo)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_e_k_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
#if 1
|
||||
// GEMM
|
||||
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_e_k_global_step_hacks),
|
||||
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
|
||||
<< std::endl;
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf =
|
||||
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k0_ho_wo_k1_global_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -1,364 +0,0 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t EPerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_W,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_W>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_wei_global,
|
||||
const FloatAB* __restrict__ p_in_global,
|
||||
FloatC* __restrict__ p_out_global) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
|
||||
<< std::endl;
|
||||
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
|
||||
<< std::endl;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop)),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, 0, OutRightPadH),
|
||||
make_pad_transform(Wo, 0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_e_k_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
// GEMM
|
||||
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_e_k_global_step_hacks),
|
||||
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
|
||||
<< std::endl;
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf =
|
||||
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k0_ho_wo_k1_global_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -0,0 +1,569 @@
|
||||
#ifndef DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
#define DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t E1_,
|
||||
ck::index_t E2_,
|
||||
ck::index_t K2_,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t E1PerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E2,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_E2,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_E2,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_K,
|
||||
ck::ActivTypeEnum_t activ_type>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... MaxPool,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ck::TensorDescriptor<MaxPool...>& max_n_k0_hx_wx_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatC* __restrict__ p_bias_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatC* __restrict__ p_d_grid,
|
||||
const int nrepeat) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
|
||||
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
|
||||
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
|
||||
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto Hx = max_n_k0_hx_wx_k1_global_desc.GetLength(I2);
|
||||
const auto Wx = max_n_k0_hx_wx_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
|
||||
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto OutRightPadHx = Number<OutRightPadH / 2>{};
|
||||
const auto OutRightPadWx = Number<OutRightPadW / 2>{};
|
||||
#else
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto OutRightPadHx = OutRightPadH / 2;
|
||||
const auto OutRightPadWx = OutRightPadW / 2;
|
||||
#endif
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
const auto E = C0 * Y * X;
|
||||
|
||||
constexpr auto E1 = Number<E1_>{};
|
||||
constexpr auto E2 = Number<E2_>{};
|
||||
constexpr auto K2 = Number<K2_>{};
|
||||
|
||||
const auto E0 = E / E1;
|
||||
|
||||
// weight tensor
|
||||
const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_pass_through_transform(C0 * Y * X),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
|
||||
|
||||
const auto a_e0_e1_k_e2_grid_desc =
|
||||
transform_tensor_descriptor(a_e_k_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2)),
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
|
||||
in_n_c0_hip_wip_e2_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C0),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c0_y_ho_x_wo_e2_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(
|
||||
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
|
||||
in_e_n_ho_wo_e2_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop),
|
||||
make_pass_through_transform(E2)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
// output tensor
|
||||
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, OutRightPadH),
|
||||
make_pad_transform(Wo, I0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// max tensor
|
||||
const auto d_k_n_hx_wx_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Hx, I0, OutRightPadHx),
|
||||
make_pad_transform(Wx, I0, OutRightPadWx)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E1 % E1PerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
|
||||
constexpr auto a_e0_e1_k_e2_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
|
||||
make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
|
||||
);
|
||||
|
||||
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
// clang-format on
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(a_e0_e1_k_e2_grid_desc),
|
||||
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
|
||||
decltype(c_k_n_hop_wop_grid_desc),
|
||||
decltype(d_k_n_hx_wx_grid_desc),
|
||||
E1,
|
||||
E2,
|
||||
K2,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
E1PerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
|
||||
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
|
||||
Sequence<2, 3, 0, 1, 4>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
ABlockTransferSrcScalarPerVector_E2,
|
||||
ABlockTransferDstScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
|
||||
9,
|
||||
BThreadTransferSrcScalarPerVector_E2,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused
|
||||
// with MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2
|
||||
1,
|
||||
CThreadTransferDstScalarPerVector_K,
|
||||
decltype(a_e0_e1_k_e2_global_step_hacks),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
|
||||
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
|
||||
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks),
|
||||
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto a_e0_e1_k0_k1_e2_grid_desc =
|
||||
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
|
||||
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
|
||||
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
|
||||
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
|
||||
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
|
||||
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
|
||||
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
|
||||
|
||||
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
|
||||
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
|
||||
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
|
||||
|
||||
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_e0_block_loop = E0 > 1;
|
||||
|
||||
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
|
||||
|
||||
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
|
||||
|
||||
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
|
||||
decltype(c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_maxpool<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_dlops_v3_maxpool<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
p_d_grid,
|
||||
a_e0_e1_k0_k1_e2_grid_desc,
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
|
||||
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
|
||||
DeviceMem b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf(
|
||||
sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2));
|
||||
DeviceMem c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf(
|
||||
sizeof(CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2));
|
||||
DeviceMem d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf(
|
||||
sizeof(DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx));
|
||||
DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf(
|
||||
sizeof(CBlockIdToBlockClusterAdaptor_K_N_H_W));
|
||||
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.ToDevice(
|
||||
&b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.ToDevice(
|
||||
&c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
|
||||
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf.ToDevice(
|
||||
&d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice(
|
||||
&c_blockid_to_k_n_h_w_block_cluster_adaptor);
|
||||
|
||||
if(has_main_e0_block_loop)
|
||||
{
|
||||
|
||||
const auto kernel = kernel_gemm_dlops_v3_maxpool<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
true,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
p_d_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
const auto kernel = kernel_gemm_dlops_v3_maxpool<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
false,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
p_d_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
|
||||
{
|
||||
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.IsKnownAtCompileTime(), "");
|
||||
static_assert(c_blockid_to_k_n_h_w_block_cluster_adaptor.IsKnownAtCompileTime(), "");
|
||||
|
||||
const auto kernel = kernel_gemm_dlops_v3_maxpool<
|
||||
GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
|
||||
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
|
||||
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
|
||||
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
|
||||
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
|
||||
has_main_e0_block_loop,
|
||||
activ_type>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_bias_grid,
|
||||
p_c_grid,
|
||||
p_d_grid);
|
||||
}
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
Reference in New Issue
Block a user