adding implicit gemm

[ROCm/composable_kernel commit: e7b8705b91]
This commit is contained in:
Chao Liu
2019-01-15 18:11:41 -06:00
parent aa885b185d
commit 8b3c613be1
10 changed files with 510 additions and 231 deletions

View File

@@ -101,10 +101,10 @@ __device__ void threadwise_direct_convolution_2(InDesc,
Float p_wei_reg[wei_reg_desc.GetElementSpace()];
// copy input tensor into register
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc);
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
// copy weight tensor into register
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc);
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths());
// do convolution
threadwise_direct_convolution_1(
@@ -159,14 +159,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
p_in + in_desc.Get1dIndex(0, 0, s, 0),
in_reg_desc,
p_in_reg,
in_reg_desc);
in_reg_desc.GetLengths());
// read first 1x1 weight
threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.Get1dIndex(0, 0, s, 0),
wei_reg_desc,
p_wei_reg,
wei_reg_desc);
wei_reg_desc.GetLengths());
// do first 1x1 conv
threadwise_direct_convolution_1(
@@ -180,7 +180,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
wei_reg_desc,
p_wei_reg,
wei_reg_desc);
wei_reg_desc.GetLengths());
// shift old input to the left
threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});
@@ -192,7 +192,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
in_reg_desc,
p_in_reg +
in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
in_desc_reg_new_read);
in_desc_reg_new_read.GetLengths());
// do 1x1 conv
threadwise_direct_convolution_1(
@@ -211,11 +211,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
wei_reg_desc,
p_wei_reg,
wei_reg_desc);
wei_reg_desc.GetLengths());
// read new input
threadwise_4d_tensor_copy(
in_desc, p_in + in_desc.Get1dIndex(0, 0, s, r), in_reg_desc, p_in_reg, in_reg_desc);
threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.Get1dIndex(0, 0, s, r),
in_reg_desc,
p_in_reg,
in_reg_desc.GetLengths());
// do 1x1 conv
threadwise_direct_convolution_1(