mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Add xdlops v4r4r4 into online compilation (#48)
* init for v4r4 xdlops olc * refactor wrap * init impl of v4r4 nchw xdlops olc * tuning * test perf * fixed v4r4 nhwc * tuned v4r4 nhwc * use gridwise_gemm_xdlops_v2r3 * swap a/b * add pointer support into offline v2r3 * debugging v4r4r4 transform for olc * change timer of olc * refactor v4r4 xdlops nchw olc * remove transform fun in v4r4 xdlops nhwc olc Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -12,11 +12,17 @@
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R5_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
|
||||
|
||||
#include "conv_tunables.hpp"
|
||||
#include "handle.hpp"
|
||||
@@ -24,10 +30,10 @@
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW,
|
||||
V4R4NHWC,
|
||||
V4R5NCHW,
|
||||
V5R1NCHW
|
||||
V4R4NCHW, // 0
|
||||
V4R5NCHW, // 1
|
||||
V4R4XDLNCHW, // 2
|
||||
V4R4XDLNHWC // 3
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -105,14 +111,16 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case ConvTensorLayout::NCHW:
|
||||
// NCHW
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
@@ -120,14 +128,16 @@ int main(int argc, char* argv[])
|
||||
break;
|
||||
case ConvTensorLayout::NHWC:
|
||||
// NHWC
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
@@ -271,10 +281,80 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4_XDLOPS_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R4XDLNCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable =
|
||||
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
tunable,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4_XDLOPS_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable =
|
||||
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
handle,
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
tunable,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution(
|
||||
in, wei, out_host, conv_strides, conv_dilations, in_left_pads, in_right_pads);
|
||||
host_direct_convolution(in,
|
||||
wei,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(out_host, out_device);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user