added strides and dilations suppport to implicit gemm v4

This commit is contained in:
Chao Liu
2019-06-13 16:20:10 -05:00
parent 1566b31736
commit b1cb48a04d
5 changed files with 96 additions and 36 deletions

View File

@@ -36,11 +36,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
}
template <class InDesc, class WeiDesc, class LowerPads, class UpperPads>
constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(InDesc,
WeiDesc,
LowerPads,
UpperPads)
template <class InDesc,
class WeiDesc,
class ConvStrides,
class ConvDilations,
class LowerPads,
class UpperPads>
constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
{
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
@@ -55,24 +58,27 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
"input & weight dimension not consistent");
constexpr auto N = in_desc.GetLength(I0);
constexpr auto HI = in_desc.GetLength(I2);
constexpr auto WI = in_desc.GetLength(I3);
constexpr index_t N = in_desc.GetLength(I0);
constexpr index_t Hi = in_desc.GetLength(I2);
constexpr index_t Wi = in_desc.GetLength(I3);
constexpr auto K = wei_desc.GetLength(I0);
constexpr auto Y = wei_desc.GetLength(I2);
constexpr auto X = wei_desc.GetLength(I3);
constexpr index_t K = wei_desc.GetLength(I0);
constexpr index_t Y = wei_desc.GetLength(I2);
constexpr index_t X = wei_desc.GetLength(I3);
constexpr auto HPadLow = LowerPads{}.Get(I0);
constexpr auto WPadLow = LowerPads{}.Get(I1);
constexpr index_t HPadLow = LowerPads{}.Get(I0);
constexpr index_t WPadLow = LowerPads{}.Get(I1);
constexpr auto HPadUp = UpperPads{}.Get(I0);
constexpr auto WPadUp = UpperPads{}.Get(I1);
constexpr index_t HPadUp = UpperPads{}.Get(I0);
constexpr index_t WPadUp = UpperPads{}.Get(I1);
constexpr auto HO = HI + HPadLow + HPadUp + 1 - Y;
constexpr auto WO = WI + WPadLow + WPadUp + 1 - X;
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;
return make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
}
template <class InDesc, class WeiDesc, class OutDesc>

View File

@@ -8,13 +8,20 @@
using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc>
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
@@ -107,6 +114,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock,
KPerBlock,
CPerBlock,