This commit is contained in:
Chao Liu
2019-03-02 17:27:37 -06:00
parent 4543d17a71
commit 5fd40ad768
22 changed files with 358 additions and 2719 deletions

View File

@@ -9,13 +9,9 @@
#include "conv_common.hip.hpp"
#include "device_direct_convolution_1.hpp"
#include "device_direct_convolution_2.hpp"
#include "device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp"
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp"
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp"
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp"
#include "device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp"
//#include "device_winograd_convolution.hip.hpp"
struct GeneratorTensor_1
{
@@ -154,8 +150,8 @@ template <class T, class LowerPads, class UpperPads>
void host_winograd_3x3_convolution(
const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr, Tensor<T>& out, LowerPads, UpperPads)
{
constexpr std::size_t OutTileSizeH = 2;
constexpr std::size_t OutTileSizeW = 2;
constexpr std::size_t HoPerTile = 2;
constexpr std::size_t WoPerTile = 2;
std::size_t N = in_nchw.mDesc.GetLengths()[0];
std::size_t C = in_nchw.mDesc.GetLengths()[1];
@@ -163,8 +159,8 @@ void host_winograd_3x3_convolution(
std::size_t WI = in_nchw.mDesc.GetLengths()[3];
std::size_t K = wei_kcsr.mDesc.GetLengths()[0];
std::size_t S = wei_kcsr.mDesc.GetLengths()[2];
std::size_t R = wei_kcsr.mDesc.GetLengths()[3];
std::size_t Y = wei_kcsr.mDesc.GetLengths()[2];
std::size_t X = wei_kcsr.mDesc.GetLengths()[3];
std::size_t HO = out.mDesc.GetLengths()[2];
std::size_t WO = out.mDesc.GetLengths()[3];
@@ -175,75 +171,91 @@ void host_winograd_3x3_convolution(
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
std::size_t InTileSizeH = OutTileSizeH + S - 1;
std::size_t InTileSizeW = OutTileSizeW + R - 1;
std::size_t HiPerTile = HoPerTile + Y - 1;
std::size_t WiPerTile = WoPerTile + X - 1;
std::size_t Y = (HO + OutTileSizeH - 1) / OutTileSizeH;
std::size_t X = (WO + OutTileSizeW - 1) / OutTileSizeW;
std::size_t HTile = (HO + HoPerTile - 1) / HoPerTile;
std::size_t WTile = (WO + WoPerTile - 1) / WoPerTile;
Tensor<T> in_hold({N, C, Y, X, InTileSizeH, InTileSizeW});
Tensor<T> in_transform({N, C, Y, X, InTileSizeH, InTileSizeW});
Tensor<T> wei_transform({K, C, InTileSizeH, InTileSizeW});
Tensor<T> out_transform({N, K, Y, X, InTileSizeH, InTileSizeH});
Tensor<T> out_hold({N, K, Y, X, OutTileSizeH, OutTileSizeW});
// Scratch tensors for the tiled transform pipeline. Tensors indexed per tile carry
// the tile grid (HTile x WTile) followed by the per-tile extents.
// in_hold / in_transform / wei_transform use the input tile extents
// (HiPerTile x WiPerTile); out_hold uses the output tile extents (HoPerTile x WoPerTile).
Tensor<T> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
Tensor<T> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
Tensor<T> wei_transform({K, C, HiPerTile, WiPerTile});
// out_transform is written at (j, i) with j < HiPerTile and i < WiPerTile, so its
// trailing extent must be WiPerTile. It was previously HiPerTile twice, which only
// worked because both extents happen to be equal for the square tile used here.
Tensor<T> out_transform({N, K, HTile, WTile, HiPerTile, WiPerTile});
Tensor<T> out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile});
auto f_in_hold = [&](auto n, auto c, auto y, auto x) {
for(int j = 0; j < InTileSizeH; ++j)
auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) {
for(int j = 0; j < HiPerTile; ++j)
{
int hi = OutTileSizeH * y + j - h_pad_low;
for(int i = 0; i < InTileSizeW; ++i)
int hi = HoPerTile * htile + j - h_pad_low;
for(int i = 0; i < WiPerTile; ++i)
{
int wi = OutTileSizeW * x + i - w_pad_low;
int wi = WoPerTile * wtile + i - w_pad_low;
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3])
{
in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi);
in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi);
}
else
{
in_hold(n, c, y, x, j, i) = T(0);
in_hold(n, c, htile, wtile, j, i) = T(0);
}
}
}
};
auto f_in_transform = [&](auto n, auto c, auto y, auto x) {
in_transform(n, c, y, x, 0, 0) = in_hold(n, c, y, x, 0, 0) - in_hold(n, c, y, x, 0, 2) -
in_hold(n, c, y, x, 2, 0) + in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 0, 1) = in_hold(n, c, y, x, 0, 1) + in_hold(n, c, y, x, 0, 2) -
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 0, 2) = -in_hold(n, c, y, x, 0, 1) + in_hold(n, c, y, x, 0, 2) +
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 0, 3) = in_hold(n, c, y, x, 0, 1) - in_hold(n, c, y, x, 0, 3) -
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 3);
auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) {
in_transform(n, c, htile, wtile, 0, 0) =
in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) -
in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 0, 1) =
in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) -
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 0, 2) =
-in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) +
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 0, 3) =
in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) -
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3);
in_transform(n, c, y, x, 1, 0) = in_hold(n, c, y, x, 1, 0) - in_hold(n, c, y, x, 1, 2) +
in_hold(n, c, y, x, 2, 0) - in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 1, 1) = in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) +
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 1, 2) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) -
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 1, 3) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 3) +
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 3);
in_transform(n, c, htile, wtile, 1, 0) =
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) +
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 1, 1) =
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 1, 2) =
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 1, 3) =
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) +
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
in_transform(n, c, y, x, 2, 0) = -in_hold(n, c, y, x, 1, 0) + in_hold(n, c, y, x, 1, 2) +
in_hold(n, c, y, x, 2, 0) - in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 2, 1) = -in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 2) +
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 2, 2) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 2) -
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
in_transform(n, c, y, x, 2, 3) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 3) +
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 3);
in_transform(n, c, htile, wtile, 2, 0) =
-in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) +
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 2, 1) =
-in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) +
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 2, 2) =
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) -
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
in_transform(n, c, htile, wtile, 2, 3) =
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) +
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
in_transform(n, c, y, x, 3, 0) = in_hold(n, c, y, x, 1, 0) - in_hold(n, c, y, x, 1, 2) -
in_hold(n, c, y, x, 3, 0) + in_hold(n, c, y, x, 3, 2);
in_transform(n, c, y, x, 3, 1) = in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) -
in_hold(n, c, y, x, 3, 1) - in_hold(n, c, y, x, 3, 2);
in_transform(n, c, y, x, 3, 2) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) +
in_hold(n, c, y, x, 3, 1) - in_hold(n, c, y, x, 3, 2);
in_transform(n, c, y, x, 3, 3) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 3) -
in_hold(n, c, y, x, 3, 1) + in_hold(n, c, y, x, 3, 3);
in_transform(n, c, htile, wtile, 3, 0) =
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) -
in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2);
in_transform(n, c, htile, wtile, 3, 1) =
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
in_transform(n, c, htile, wtile, 3, 2) =
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
in_transform(n, c, htile, wtile, 3, 3) =
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) -
in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3);
};
auto f_wei_transform = [&](auto k, auto c) {
@@ -292,69 +304,69 @@ void host_winograd_3x3_convolution(
wei_transform(k, c, 3, 3) = wei_kcsr(k, c, 2, 2);
};
auto f_out_transform = [&](auto n, auto k, auto y, auto x) {
for(int j = 0; j < InTileSizeH; ++j)
auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) {
for(int j = 0; j < HiPerTile; ++j)
{
for(int i = 0; i < InTileSizeW; ++i)
for(int i = 0; i < WiPerTile; ++i)
{
double v = 0;
for(int c = 0; c < C; ++c)
{
v += in_transform(n, c, y, x, j, i) * wei_transform(k, c, j, i);
v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i);
}
out_transform(n, k, y, x, j, i) = v;
out_transform(n, k, htile, wtile, j, i) = v;
}
}
};
auto f_out_hold = [&](auto n, auto k, auto y, auto x) {
out_hold(n, k, y, x, 0, 0) =
out_transform(n, k, y, x, 0, 0) + out_transform(n, k, y, x, 0, 1) +
out_transform(n, k, y, x, 0, 2) + out_transform(n, k, y, x, 1, 0) +
out_transform(n, k, y, x, 1, 1) + out_transform(n, k, y, x, 1, 2) +
out_transform(n, k, y, x, 2, 0) + out_transform(n, k, y, x, 2, 1) +
out_transform(n, k, y, x, 2, 2);
out_hold(n, k, y, x, 0, 1) =
out_transform(n, k, y, x, 0, 1) - out_transform(n, k, y, x, 0, 2) -
out_transform(n, k, y, x, 0, 3) + out_transform(n, k, y, x, 1, 1) -
out_transform(n, k, y, x, 1, 2) - out_transform(n, k, y, x, 1, 3) +
out_transform(n, k, y, x, 2, 1) - out_transform(n, k, y, x, 2, 2) -
out_transform(n, k, y, x, 2, 3);
out_hold(n, k, y, x, 1, 0) =
out_transform(n, k, y, x, 1, 0) + out_transform(n, k, y, x, 1, 1) +
out_transform(n, k, y, x, 1, 2) - out_transform(n, k, y, x, 2, 0) -
out_transform(n, k, y, x, 2, 1) - out_transform(n, k, y, x, 2, 2) -
out_transform(n, k, y, x, 3, 0) - out_transform(n, k, y, x, 3, 1) -
out_transform(n, k, y, x, 3, 2);
out_hold(n, k, y, x, 1, 1) =
out_transform(n, k, y, x, 1, 1) - out_transform(n, k, y, x, 1, 2) -
out_transform(n, k, y, x, 1, 3) - out_transform(n, k, y, x, 2, 1) +
out_transform(n, k, y, x, 2, 2) + out_transform(n, k, y, x, 2, 3) -
out_transform(n, k, y, x, 3, 1) + out_transform(n, k, y, x, 3, 2) +
out_transform(n, k, y, x, 3, 3);
auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) {
out_hold(n, k, htile, wtile, 0, 0) =
out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) +
out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) +
out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) +
out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) +
out_transform(n, k, htile, wtile, 2, 2);
out_hold(n, k, htile, wtile, 0, 1) =
out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) -
out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) -
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) +
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
out_transform(n, k, htile, wtile, 2, 3);
out_hold(n, k, htile, wtile, 1, 0) =
out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) +
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) -
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) -
out_transform(n, k, htile, wtile, 3, 2);
out_hold(n, k, htile, wtile, 1, 1) =
out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) -
out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) +
out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) -
out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) +
out_transform(n, k, htile, wtile, 3, 3);
};
auto f_out = [&](auto n, auto k, auto y, auto x) {
for(int j = 0; j < OutTileSizeH; ++j)
auto f_out = [&](auto n, auto k, auto htile, auto wtile) {
for(int j = 0; j < HoPerTile; ++j)
{
std::size_t ho = OutTileSizeH * y + j;
for(int i = 0; i < OutTileSizeW; ++i)
std::size_t ho = HoPerTile * htile + j;
for(int i = 0; i < WoPerTile; ++i)
{
std::size_t wo = OutTileSizeW * x + i;
out(n, k, ho, wo) = out_hold(n, k, y, x, j, i);
std::size_t wo = WoPerTile * wtile + i;
out(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
}
}
};
std::size_t num_thread = std::thread::hardware_concurrency();
make_ParallelTensorFunctor(f_in_hold, N, C, Y, X)(num_thread);
make_ParallelTensorFunctor(f_in_transform, N, C, Y, X)(num_thread);
make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread);
make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread);
make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread);
make_ParallelTensorFunctor(f_out_transform, N, K, Y, X)(num_thread);
make_ParallelTensorFunctor(f_out_hold, N, K, Y, X)(num_thread);
make_ParallelTensorFunctor(f_out, N, K, Y, X)(num_thread);
make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread);
make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread);
make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread);
}
template <class T>
@@ -387,8 +399,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
@@ -399,8 +411,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 34;
constexpr unsigned WI = 34;
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
@@ -411,8 +423,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
#elif 0
// 3x3, 58x58
constexpr unsigned N = 64;
@@ -420,8 +432,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
#elif 0
// 5x5, 36x36
constexpr unsigned N = 64;
@@ -429,8 +441,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 36;
constexpr unsigned WI = 36;
constexpr unsigned K = 64;
constexpr unsigned S = 5;
constexpr unsigned R = 5;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
@@ -441,8 +453,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 38;
constexpr unsigned WI = 38;
constexpr unsigned K = 64;
constexpr unsigned S = 7;
constexpr unsigned R = 7;
constexpr unsigned Y = 7;
constexpr unsigned X = 7;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
@@ -453,8 +465,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
#elif 0
// 3x3 filter, 58x58 image, 0x0 padding
constexpr unsigned N = 16;
@@ -462,8 +474,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 58;
constexpr unsigned WI = 58;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
@@ -474,8 +486,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 56;
constexpr unsigned WI = 56;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
@@ -486,8 +498,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
@@ -498,8 +510,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 512;
constexpr unsigned S = 1;
constexpr unsigned R = 1;
constexpr unsigned Y = 1;
constexpr unsigned X = 1;
constexpr unsigned HPad = 0;
constexpr unsigned WPad = 0;
@@ -510,8 +522,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 20;
constexpr unsigned WI = 84;
constexpr unsigned K = 256;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
@@ -522,8 +534,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 112;
constexpr unsigned WI = 112;
constexpr unsigned K = 128;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
constexpr unsigned Y = 3;
constexpr unsigned X = 3;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
@@ -534,8 +546,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 20;
constexpr unsigned WI = 86;
constexpr unsigned K = 512;
constexpr unsigned S = 5;
constexpr unsigned R = 5;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 1;
constexpr unsigned WPad = 1;
@@ -546,8 +558,8 @@ int main(int argc, char* argv[])
constexpr unsigned HI = 28;
constexpr unsigned WI = 28;
constexpr unsigned K = 32;
constexpr unsigned S = 5;
constexpr unsigned R = 5;
constexpr unsigned Y = 5;
constexpr unsigned X = 5;
constexpr unsigned HPad = 2;
constexpr unsigned WPad = 2;
@@ -557,7 +569,7 @@ int main(int argc, char* argv[])
auto upper_pads = Sequence<HPad, WPad>{};
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, Y, X>{});
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcsr_desc, lower_pads, upper_pads);
@@ -600,14 +612,8 @@ int main(int argc, char* argv[])
device_direct_convolution_1
#elif 0
device_direct_convolution_2
#elif 0
device_implicit_gemm_convolution_1_nchw_kcsr_nkhw
#elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1
device_implicit_gemm_convolution_1_chwn_csrk_khwn
#elif 0
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
#elif 0
device_implicit_gemm_convolution_2_chwn_csrk_khwn
#endif
@@ -627,7 +633,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
if(S == 3 && R == 3)
if(Y == 3 && X == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
}