mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Dynamic tensor descriptor (#24)
* support dynamic tensor descriptor * use buffer load OOB feature for padding case * add navi support * add int8x4 inference kernel Co-authored-by: Chao Liu <chao@ixt-rack-81.local.lan> Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -4,10 +4,7 @@
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "print_array.hpp"
|
||||
#include "print_sequence.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
@@ -54,10 +51,10 @@ int main(int argc, char* argv[])
|
||||
#elif 0
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -156,13 +153,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<2, 2>;
|
||||
using RightPads = Sequence<2, 2>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
@@ -197,7 +194,7 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<2, 2>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
@@ -211,11 +208,11 @@ int main(int argc, char* argv[])
|
||||
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("RightPads", RightPads{});
|
||||
print_sequence("ConvStrides", ConvStrides{});
|
||||
print_sequence("ConvDilations", ConvDilations{});
|
||||
print_array("LeftPads", LeftPads{});
|
||||
print_array("LeftPads", LeftPads{});
|
||||
print_array("RightPads", RightPads{});
|
||||
print_array("ConvStrides", ConvStrides{});
|
||||
print_array("ConvDilations", ConvDilations{});
|
||||
|
||||
Tensor<float> in_nchw_device(make_HostTensorDescriptor(in_nchw_desc));
|
||||
Tensor<float> in_nchw_host(make_HostTensorDescriptor(in_nchw_desc));
|
||||
@@ -248,7 +245,7 @@ int main(int argc, char* argv[])
|
||||
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
conv_bwd_data_driver.cpp
|
||||
@@ -5,27 +5,29 @@
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print_array.hpp"
|
||||
#include "print_sequence.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
// 1x1, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
@@ -35,6 +37,135 @@ int main(int argc, char* argv[])
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 1;
|
||||
constexpr index_t HI = 1024;
|
||||
constexpr index_t WI = 2048;
|
||||
constexpr index_t K = 4;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 36x36, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 37;
|
||||
constexpr index_t WI = 37;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 71x71
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 71;
|
||||
constexpr index_t WI = 71;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
// 1x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1536;
|
||||
@@ -70,7 +201,7 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t C = 96;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -94,7 +225,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 7x1, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -109,7 +240,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x7, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -141,12 +272,11 @@ int main(int argc, char* argv[])
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 147x147
|
||||
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 32;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 147;
|
||||
constexpr index_t WI = 147;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -157,7 +287,6 @@ int main(int argc, char* argv[])
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 149x149
|
||||
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 32;
|
||||
constexpr index_t HI = 149;
|
||||
@@ -201,7 +330,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 288;
|
||||
@@ -244,21 +373,6 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 0>;
|
||||
using RightPads = Sequence<1, 0>;
|
||||
#elif 0
|
||||
// 3x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 448;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 0>;
|
||||
using RightPads = Sequence<1, 0>;
|
||||
#elif 0
|
||||
@@ -278,7 +392,6 @@ int main(int argc, char* argv[])
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 7x1, 73x73
|
||||
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 73;
|
||||
@@ -352,7 +465,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -382,7 +495,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1, 56x56, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
@@ -442,7 +555,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x1, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
@@ -472,7 +585,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
@@ -487,7 +600,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
@@ -512,17 +625,26 @@ int main(int argc, char* argv[])
|
||||
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
print_sequence("LeftPads", LeftPads{});
|
||||
print_sequence("RightPads", RightPads{});
|
||||
print_sequence("ConvStrides", ConvStrides{});
|
||||
print_sequence("ConvDilations", ConvDilations{});
|
||||
print_array("LeftPads", to_multi_index(LeftPads{}));
|
||||
print_array("RightPads", to_multi_index(RightPads{}));
|
||||
print_array("ConvStrides", to_multi_index(ConvStrides{}));
|
||||
print_array("ConvDilations", to_multi_index(ConvDilations{}));
|
||||
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
using out_data_t = float;
|
||||
#else
|
||||
using in_data_t = half_float::half;
|
||||
using out_data_t = half_float::half;
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 0
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = int8_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
constexpr index_t in_vector_size = 4;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
|
||||
@@ -532,14 +654,15 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(argc != 3)
|
||||
if(argc != 4)
|
||||
{
|
||||
printf("arg1: do_verification, arg2: nrepeat\n");
|
||||
printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool do_verification = atoi(argv[1]);
|
||||
index_t nrepeat = atoi(argv[2]);
|
||||
bool do_log = atoi(argv[2]);
|
||||
index_t nrepeat = atoi(argv[3]);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -548,9 +671,9 @@ int main(int argc, char* argv[])
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
@@ -565,59 +688,112 @@ int main(int argc, char* argv[])
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 1
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#if 0
|
||||
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>
|
||||
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>
|
||||
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
#if 0
|
||||
if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
|
||||
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
|
||||
{
|
||||
host_winograd_3x3_convolution(
|
||||
in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{});
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
host_direct_convolution(in_nchw,
|
||||
wei_kcyx,
|
||||
out_nkhw_host,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{});
|
||||
}
|
||||
host_direct_convolution(in_nchw,
|
||||
wei_kcyx,
|
||||
out_nkhw_host,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{});
|
||||
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
|
||||
#if 0
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
|
||||
#endif
|
||||
if(do_log)
|
||||
{
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
conv_driver.cpp
|
||||
Reference in New Issue
Block a user