DL GEMM fp32/fp16/int8 (#41)

* add threadwise copy the copy a tensor in one copy, added kpack to DL GEMM

* add kpack into fwd v4r5 nchw fp32

[ROCm/composable_kernel commit: b8b2d0a6d1]
This commit is contained in:
Chao Liu
2021-07-04 22:50:29 -05:00
committed by GitHub
parent 892c52c2ed
commit 0d7baf0e50
21 changed files with 4508 additions and 270 deletions

View File

@@ -14,7 +14,9 @@
#include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp"
@@ -24,23 +26,27 @@
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R5R2_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
enum ConvForwardAlgo
{
V4R4NCHW, // 0
V4R4NHWC, // 1
V4R5NCHW, // 2
V5R1NCHW, // 3
V4R4XDLNCHW, // 4
V4R4R2XDLNHWC, // 5
V4R4R3XDLNHWC, // 6
V4R4R4XDLNHWC // 7
V4R4R2NHWC, // 2
V4R5NCHW, // 3
V4R5R2NCHW, // 4
V5R1NCHW, // 5
V4R4XDLNCHW, // 6
V4R4R2XDLNHWC, // 7
V4R4R3XDLNHWC, // 8
V4R4R4XDLNHWC // 9
};
int main(int argc, char* argv[])
@@ -132,21 +138,18 @@ int main(int argc, char* argv[])
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif
#if 0
constexpr index_t in_vector_size = 1;
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#if 1
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 1
constexpr index_t in_vector_size = 1;
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#elif 1
constexpr index_t in_vector_size = 16;
using in_data_t = int8_t;
using acc_data_t = int32_t;
using out_data_t = int8_t;
using in_data_t = int8_t;
using acc_data_t = int32_t;
using out_data_t = int8_t;
#endif
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
@@ -348,6 +351,33 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_FWD_V4R4R2_NHWC
if(algo == ConvForwardAlgo::V4R4R2NHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R5_NCHW
if(algo == ConvForwardAlgo::V4R5NCHW)
{
@@ -374,6 +404,33 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_FWD_V4R5R2_NCHW
if(algo == ConvForwardAlgo::V4R5R2NCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V5R1_NCHW
if(algo == ConvForwardAlgo::V5R1NCHW)
{
@@ -385,7 +442,7 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
in_vector_size,
16,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
@@ -525,10 +582,10 @@ int main(int argc, char* argv[])
#if 0
if(do_log)
{
LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl;
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
}
#endif
}