Hybrid direct + implicit GEMM forward convolution NCHWc v5r1 (#25)

* Hybrid direct + implicit GEMM forward convolution NCHWc v5r1. Input tensor bypass LDS. Support fp32/fp16/int8

[ROCm/composable_kernel commit: 792a20fa5b]
This commit is contained in:
zjing14
2021-04-07 16:47:29 -05:00
committed by GitHub
parent ca8a932775
commit 2457224dc9
9 changed files with 1059 additions and 155 deletions

View File

@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
@@ -62,8 +62,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 16;
@@ -642,7 +642,7 @@ int main(int argc, char* argv[])
using out_data_t = int8_t;
#elif 1
using in_data_t = int8_t;
constexpr index_t in_vector_size = 4;
constexpr index_t in_vector_size = 16;
using acc_data_t = int32_t;
using out_data_t = int8_t;
#endif
@@ -741,7 +741,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 1
#elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
in_vector_size,
acc_data_t,