mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Restructure gridwise and blockwise GEMM, add tensor contraction and FWD-v4r5 (#36)
* experimenting magic number division
* overhauling fwd-v4r4 to clearly reflect transformation graph
* added fwd-v4r5
* bug fix for make_dynamic_naive_tensor_descriptor_aligned_v2
* bug fix and added sanity-check in transform_dynamic_tensor_descriptor
* added conv_driver_v2
[ROCm/composable_kernel commit: 30072aec37]
This commit is contained in:
@@ -14,19 +14,41 @@
|
||||
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
if(argc != 5)
|
||||
{
|
||||
printf("arg1: do_verification, arg2: do_log, arg3: init_method, arg4: nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const bool do_verification = atoi(argv[1]);
|
||||
const int init_method = atoi(argv[2]);
|
||||
const bool do_log = atoi(argv[3]);
|
||||
const int nrepeat = atoi(argv[4]);
|
||||
|
||||
#if 0
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 8;
|
||||
constexpr index_t Hi = 4;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t Hi = 540;
|
||||
constexpr index_t Wi = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -34,13 +56,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t Hi = 270;
|
||||
constexpr index_t Wi = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -48,27 +70,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t Hi = 1080;
|
||||
constexpr index_t Wi = 1920;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -76,13 +84,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 1;
|
||||
constexpr index_t HI = 1024;
|
||||
constexpr index_t WI = 2048;
|
||||
constexpr index_t Hi = 1024;
|
||||
constexpr index_t Wi = 2048;
|
||||
constexpr index_t K = 4;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -90,13 +98,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t Hi = 540;
|
||||
constexpr index_t Wi = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -104,13 +112,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t Hi = 270;
|
||||
constexpr index_t Wi = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -118,14 +126,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 36x36, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 37;
|
||||
constexpr index_t WI = 37;
|
||||
constexpr index_t Hi = 37;
|
||||
constexpr index_t Wi = 37;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -133,14 +141,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -148,14 +156,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 71x71
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 71;
|
||||
constexpr index_t WI = 71;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -163,14 +171,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t Hi = 8;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -178,14 +186,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 73x73
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 160;
|
||||
constexpr index_t HI = 73;
|
||||
constexpr index_t WI = 73;
|
||||
constexpr index_t Hi = 73;
|
||||
constexpr index_t Wi = 73;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -193,14 +201,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 96;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -208,14 +216,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 71x71
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 71;
|
||||
constexpr index_t WI = 71;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 192;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -223,14 +231,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 7x1, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
@@ -238,14 +246,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
using InLeftPads = Sequence<3, 0>;
|
||||
using InRightPads = Sequence<3, 0>;
|
||||
#elif 1
|
||||
// 1x7, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
@@ -253,14 +261,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 3>;
|
||||
using RightPads = Sequence<0, 3>;
|
||||
using InLeftPads = Sequence<0, 3>;
|
||||
using InRightPads = Sequence<0, 3>;
|
||||
#elif 0
|
||||
// 3x3, 299x299 stride=2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 3;
|
||||
constexpr index_t HI = 299;
|
||||
constexpr index_t WI = 299;
|
||||
constexpr index_t Hi = 299;
|
||||
constexpr index_t Wi = 299;
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -268,14 +276,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 147x147
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 147;
|
||||
constexpr index_t WI = 147;
|
||||
constexpr index_t Hi = 147;
|
||||
constexpr index_t Wi = 147;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -283,14 +291,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 149x149
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 32;
|
||||
constexpr index_t HI = 149;
|
||||
constexpr index_t WI = 149;
|
||||
constexpr index_t Hi = 149;
|
||||
constexpr index_t Wi = 149;
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -298,14 +306,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 17x17, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 192;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -313,14 +321,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 35x35
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 384;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -328,14 +336,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 288;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -343,14 +351,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x3, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 384;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t Hi = 8;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 448;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 3;
|
||||
@@ -358,14 +366,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 1>;
|
||||
using RightPads = Sequence<0, 1>;
|
||||
using InLeftPads = Sequence<0, 1>;
|
||||
using InRightPads = Sequence<0, 1>;
|
||||
#elif 0
|
||||
// 3x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 448;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t Hi = 8;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 1;
|
||||
@@ -373,14 +381,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 0>;
|
||||
using RightPads = Sequence<1, 0>;
|
||||
using InLeftPads = Sequence<1, 0>;
|
||||
using InRightPads = Sequence<1, 0>;
|
||||
#elif 0
|
||||
// 3x3, 147x147
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 147;
|
||||
constexpr index_t WI = 147;
|
||||
constexpr index_t Hi = 147;
|
||||
constexpr index_t Wi = 147;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -388,14 +396,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 7x1, 73x73
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 73;
|
||||
constexpr index_t WI = 73;
|
||||
constexpr index_t Hi = 73;
|
||||
constexpr index_t Wi = 73;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
@@ -403,14 +411,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
using InLeftPads = Sequence<3, 0>;
|
||||
using InRightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
// 3x3, 73x73
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 73;
|
||||
constexpr index_t WI = 73;
|
||||
constexpr index_t Hi = 73;
|
||||
constexpr index_t Wi = 73;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -418,14 +426,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 2048;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -433,14 +441,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -448,14 +456,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -463,14 +471,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -478,14 +486,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
// 3x3, 14x14
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -493,14 +501,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 56x56, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -508,14 +516,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 7x7, 230x230 stride=2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 3;
|
||||
constexpr index_t HI = 230;
|
||||
constexpr index_t WI = 230;
|
||||
constexpr index_t Hi = 230;
|
||||
constexpr index_t Wi = 230;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 7;
|
||||
@@ -523,14 +531,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 28x28, stride = 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -538,14 +546,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 28x28, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -553,14 +561,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
// 1x1, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t Hi = 7;
|
||||
constexpr index_t Wi = 7;
|
||||
constexpr index_t K = 2048;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -568,14 +576,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t Hi = 7;
|
||||
constexpr index_t Wi = 7;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -583,14 +591,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -598,14 +606,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -613,82 +621,86 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#endif
|
||||
|
||||
auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
|
||||
auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
|
||||
auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
|
||||
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
|
||||
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
|
||||
|
||||
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
print_array("LeftPads", to_multi_index(LeftPads{}));
|
||||
print_array("RightPads", to_multi_index(RightPads{}));
|
||||
print_array("ConvStrides", to_multi_index(ConvStrides{}));
|
||||
print_array("ConvDilations", to_multi_index(ConvDilations{}));
|
||||
constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1;
|
||||
constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1;
|
||||
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using in_data_t = typename vector_type<float, in_vector_size>::type;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 0
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using in_data_t = typename vector_type<float, in_vector_size>::type;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = int8_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
constexpr index_t in_vector_size = 16;
|
||||
using in_data_t = typename vector_type<int8_t, in_vector_size>::type;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
|
||||
Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
|
||||
Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
|
||||
Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
|
||||
Tensor<in_data_t> in_nchw(HostTensorDescriptor(std::initializer_list<index_t>{N, C, Hi, Wi}));
|
||||
Tensor<in_data_t> wei_kcyx(HostTensorDescriptor(std::initializer_list<index_t>{K, C, Y, X}));
|
||||
Tensor<out_data_t> out_nkhw_host(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
|
||||
Tensor<out_data_t> out_nkhw_device(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
|
||||
|
||||
ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: ");
|
||||
ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: ");
|
||||
|
||||
print_array("InLeftPads", InLeftPads{});
|
||||
print_array("InRightPads", InRightPads{});
|
||||
print_array("ConvStrides", ConvStrides{});
|
||||
print_array("ConvDilations", ConvDilations{});
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(argc != 4)
|
||||
{
|
||||
printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool do_verification = atoi(argv[1]);
|
||||
bool do_log = atoi(argv[2]);
|
||||
index_t nrepeat = atoi(argv[3]);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
#if 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
|
||||
#endif
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
|
||||
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
|
||||
constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
|
||||
|
||||
#if 1
|
||||
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
@@ -697,8 +709,8 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
InLeftPads{},
|
||||
InRightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
@@ -709,8 +721,8 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
InLeftPads{},
|
||||
InRightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
|
||||
@@ -721,58 +733,9 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
InLeftPads{},
|
||||
InRightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>
|
||||
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
@@ -782,8 +745,8 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_host,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{});
|
||||
InLeftPads{},
|
||||
InRightPads{});
|
||||
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
|
||||
|
||||
410
driver/src/conv_driver_v2.cpp
Normal file
410
driver/src/conv_driver_v2.cpp
Normal file
@@ -0,0 +1,410 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_NHWC 1
|
||||
#define USE_CONV_FWD_V4R5_NCHW 1
|
||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW,
|
||||
V4R4NHWC,
|
||||
V4R5NCHW,
|
||||
V5R1NCHW
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
||||
const bool do_verification = atoi(argv[3]);
|
||||
const int init_method = atoi(argv[4]);
|
||||
const bool do_log = atoi(argv[5]);
|
||||
const int nrepeat = atoi(argv[6]);
|
||||
|
||||
const index_t N = atoi(argv[7]);
|
||||
const index_t K = atoi(argv[8]);
|
||||
const index_t C = atoi(argv[9]);
|
||||
const index_t Y = atoi(argv[10]);
|
||||
const index_t X = atoi(argv[11]);
|
||||
const index_t Hi = atoi(argv[12]);
|
||||
const index_t Wi = atoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = atoi(argv[14]);
|
||||
const index_t conv_stride_w = atoi(argv[15]);
|
||||
const index_t conv_dilation_h = atoi(argv[16]);
|
||||
const index_t conv_dilation_w = atoi(argv[17]);
|
||||
const index_t in_left_pad_h = atoi(argv[18]);
|
||||
const index_t in_left_pad_w = atoi(argv[19]);
|
||||
const index_t in_right_pad_h = atoi(argv[20]);
|
||||
const index_t in_right_pad_w = atoi(argv[21]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
||||
const bool do_verification = atoi(argv[3]);
|
||||
const int init_method = atoi(argv[4]);
|
||||
const bool do_log = atoi(argv[5]);
|
||||
const int nrepeat = atoi(argv[6]);
|
||||
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
const index_t conv_stride_h = 1;
|
||||
const index_t conv_stride_w = 1;
|
||||
const index_t conv_dilation_h = 1;
|
||||
const index_t conv_dilation_w = 1;
|
||||
const index_t in_left_pad_h = 0;
|
||||
const index_t in_left_pad_w = 3;
|
||||
const index_t in_right_pad_h = 0;
|
||||
const index_t in_right_pad_w = 3;
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
constexpr index_t in_vector_size = 16;
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
switch(layout)
|
||||
{
|
||||
case ConvTensorLayout::NCHW:
|
||||
// NCHW
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
break;
|
||||
case ConvTensorLayout::NHWC:
|
||||
// NHWC
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
break;
|
||||
default: throw std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> out_host(out_lengths_host);
|
||||
Tensor<out_data_t> out_device(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchw = [&]() {
|
||||
#if USE_DYNAMIC_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
#if USE_DYNAMIC_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
const auto nhwc_desc = f_make_for_device_nhwc();
|
||||
|
||||
#if USE_CONV_FWD_V4R4_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R4NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4NHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R5_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R5NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V5R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V5R1NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution(in,
|
||||
wei,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(out_host, out_device);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,18 +3,6 @@
|
||||
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
|
||||
: mLens(lens), mStrides(strides)
|
||||
{
|
||||
}
|
||||
|
||||
void HostTensorDescriptor::CalculateStrides()
|
||||
{
|
||||
mStrides.clear();
|
||||
@@ -45,3 +33,16 @@ std::size_t HostTensorDescriptor::GetElementSpace() const
|
||||
const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; }
|
||||
|
||||
const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
|
||||
{
|
||||
os << "dim " << desc.GetNumOfDimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.GetLengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.GetStrides(), ", ");
|
||||
os << "}" << std::endl;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user