mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
clean up
This commit is contained in:
@@ -9,13 +9,9 @@
|
||||
#include "conv_common.hip.hpp"
|
||||
#include "device_direct_convolution_1.hpp"
|
||||
#include "device_direct_convolution_2.hpp"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_kcsr_nkhw.hpp"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.hpp"
|
||||
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn.hpp"
|
||||
#include "device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded.hpp"
|
||||
#include "device_implicit_gemm_convolution_2_cnhw_csrk_knhw.hpp"
|
||||
#include "device_implicit_gemm_convolution_2_chwn_csrk_khwn.hpp"
|
||||
//#include "device_winograd_convolution.hip.hpp"
|
||||
|
||||
struct GeneratorTensor_1
|
||||
{
|
||||
@@ -154,8 +150,8 @@ template <class T, class LowerPads, class UpperPads>
|
||||
void host_winograd_3x3_convolution(
|
||||
const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr, Tensor<T>& out, LowerPads, UpperPads)
|
||||
{
|
||||
constexpr std::size_t OutTileSizeH = 2;
|
||||
constexpr std::size_t OutTileSizeW = 2;
|
||||
constexpr std::size_t HoPerTile = 2;
|
||||
constexpr std::size_t WoPerTile = 2;
|
||||
|
||||
std::size_t N = in_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = in_nchw.mDesc.GetLengths()[1];
|
||||
@@ -163,8 +159,8 @@ void host_winograd_3x3_convolution(
|
||||
std::size_t WI = in_nchw.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t K = wei_kcsr.mDesc.GetLengths()[0];
|
||||
std::size_t S = wei_kcsr.mDesc.GetLengths()[2];
|
||||
std::size_t R = wei_kcsr.mDesc.GetLengths()[3];
|
||||
std::size_t Y = wei_kcsr.mDesc.GetLengths()[2];
|
||||
std::size_t X = wei_kcsr.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t HO = out.mDesc.GetLengths()[2];
|
||||
std::size_t WO = out.mDesc.GetLengths()[3];
|
||||
@@ -175,75 +171,91 @@ void host_winograd_3x3_convolution(
|
||||
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
|
||||
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
|
||||
|
||||
std::size_t InTileSizeH = OutTileSizeH + S - 1;
|
||||
std::size_t InTileSizeW = OutTileSizeW + R - 1;
|
||||
std::size_t HiPerTile = HoPerTile + Y - 1;
|
||||
std::size_t WiPerTile = WoPerTile + X - 1;
|
||||
|
||||
std::size_t Y = (HO + OutTileSizeH - 1) / OutTileSizeH;
|
||||
std::size_t X = (WO + OutTileSizeW - 1) / OutTileSizeW;
|
||||
std::size_t HTile = (HO + HoPerTile - 1) / HoPerTile;
|
||||
std::size_t WTile = (WO + WoPerTile - 1) / WoPerTile;
|
||||
|
||||
Tensor<T> in_hold({N, C, Y, X, InTileSizeH, InTileSizeW});
|
||||
Tensor<T> in_transform({N, C, Y, X, InTileSizeH, InTileSizeW});
|
||||
Tensor<T> wei_transform({K, C, InTileSizeH, InTileSizeW});
|
||||
Tensor<T> out_transform({N, K, Y, X, InTileSizeH, InTileSizeH});
|
||||
Tensor<T> out_hold({N, K, Y, X, OutTileSizeH, OutTileSizeW});
|
||||
Tensor<T> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||
Tensor<T> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||
Tensor<T> wei_transform({K, C, HiPerTile, WiPerTile});
|
||||
Tensor<T> out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile});
|
||||
Tensor<T> out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile});
|
||||
|
||||
auto f_in_hold = [&](auto n, auto c, auto y, auto x) {
|
||||
for(int j = 0; j < InTileSizeH; ++j)
|
||||
auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HiPerTile; ++j)
|
||||
{
|
||||
int hi = OutTileSizeH * y + j - h_pad_low;
|
||||
for(int i = 0; i < InTileSizeW; ++i)
|
||||
int hi = HoPerTile * htile + j - h_pad_low;
|
||||
for(int i = 0; i < WiPerTile; ++i)
|
||||
{
|
||||
int wi = OutTileSizeW * x + i - w_pad_low;
|
||||
int wi = WoPerTile * wtile + i - w_pad_low;
|
||||
|
||||
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in_nchw.mDesc.GetLengths()[3])
|
||||
{
|
||||
in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi);
|
||||
in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi);
|
||||
}
|
||||
else
|
||||
{
|
||||
in_hold(n, c, y, x, j, i) = T(0);
|
||||
in_hold(n, c, htile, wtile, j, i) = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto f_in_transform = [&](auto n, auto c, auto y, auto x) {
|
||||
in_transform(n, c, y, x, 0, 0) = in_hold(n, c, y, x, 0, 0) - in_hold(n, c, y, x, 0, 2) -
|
||||
in_hold(n, c, y, x, 2, 0) + in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 0, 1) = in_hold(n, c, y, x, 0, 1) + in_hold(n, c, y, x, 0, 2) -
|
||||
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 0, 2) = -in_hold(n, c, y, x, 0, 1) + in_hold(n, c, y, x, 0, 2) +
|
||||
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 0, 3) = in_hold(n, c, y, x, 0, 1) - in_hold(n, c, y, x, 0, 3) -
|
||||
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 3);
|
||||
auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) {
|
||||
in_transform(n, c, htile, wtile, 0, 0) =
|
||||
in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 1) =
|
||||
in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 2) =
|
||||
-in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 3) =
|
||||
in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, y, x, 1, 0) = in_hold(n, c, y, x, 1, 0) - in_hold(n, c, y, x, 1, 2) +
|
||||
in_hold(n, c, y, x, 2, 0) - in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 1, 1) = in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) +
|
||||
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 1, 2) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) -
|
||||
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 1, 3) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 3) +
|
||||
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 3);
|
||||
in_transform(n, c, htile, wtile, 1, 0) =
|
||||
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 1) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 2) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 3) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, y, x, 2, 0) = -in_hold(n, c, y, x, 1, 0) + in_hold(n, c, y, x, 1, 2) +
|
||||
in_hold(n, c, y, x, 2, 0) - in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 2, 1) = -in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 2) +
|
||||
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 2, 2) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 2) -
|
||||
in_hold(n, c, y, x, 2, 1) + in_hold(n, c, y, x, 2, 2);
|
||||
in_transform(n, c, y, x, 2, 3) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 3) +
|
||||
in_hold(n, c, y, x, 2, 1) - in_hold(n, c, y, x, 2, 3);
|
||||
in_transform(n, c, htile, wtile, 2, 0) =
|
||||
-in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 1) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 2) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 3) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, y, x, 3, 0) = in_hold(n, c, y, x, 1, 0) - in_hold(n, c, y, x, 1, 2) -
|
||||
in_hold(n, c, y, x, 3, 0) + in_hold(n, c, y, x, 3, 2);
|
||||
in_transform(n, c, y, x, 3, 1) = in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) -
|
||||
in_hold(n, c, y, x, 3, 1) - in_hold(n, c, y, x, 3, 2);
|
||||
in_transform(n, c, y, x, 3, 2) = -in_hold(n, c, y, x, 1, 1) + in_hold(n, c, y, x, 1, 2) +
|
||||
in_hold(n, c, y, x, 3, 1) - in_hold(n, c, y, x, 3, 2);
|
||||
in_transform(n, c, y, x, 3, 3) = in_hold(n, c, y, x, 1, 1) - in_hold(n, c, y, x, 1, 3) -
|
||||
in_hold(n, c, y, x, 3, 1) + in_hold(n, c, y, x, 3, 3);
|
||||
in_transform(n, c, htile, wtile, 3, 0) =
|
||||
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 1) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 2) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 3) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) -
|
||||
in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3);
|
||||
};
|
||||
|
||||
auto f_wei_transform = [&](auto k, auto c) {
|
||||
@@ -292,69 +304,69 @@ void host_winograd_3x3_convolution(
|
||||
wei_transform(k, c, 3, 3) = wei_kcsr(k, c, 2, 2);
|
||||
};
|
||||
|
||||
auto f_out_transform = [&](auto n, auto k, auto y, auto x) {
|
||||
for(int j = 0; j < InTileSizeH; ++j)
|
||||
auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HiPerTile; ++j)
|
||||
{
|
||||
for(int i = 0; i < InTileSizeW; ++i)
|
||||
for(int i = 0; i < WiPerTile; ++i)
|
||||
{
|
||||
double v = 0;
|
||||
for(int c = 0; c < C; ++c)
|
||||
{
|
||||
v += in_transform(n, c, y, x, j, i) * wei_transform(k, c, j, i);
|
||||
v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i);
|
||||
}
|
||||
|
||||
out_transform(n, k, y, x, j, i) = v;
|
||||
out_transform(n, k, htile, wtile, j, i) = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto f_out_hold = [&](auto n, auto k, auto y, auto x) {
|
||||
out_hold(n, k, y, x, 0, 0) =
|
||||
out_transform(n, k, y, x, 0, 0) + out_transform(n, k, y, x, 0, 1) +
|
||||
out_transform(n, k, y, x, 0, 2) + out_transform(n, k, y, x, 1, 0) +
|
||||
out_transform(n, k, y, x, 1, 1) + out_transform(n, k, y, x, 1, 2) +
|
||||
out_transform(n, k, y, x, 2, 0) + out_transform(n, k, y, x, 2, 1) +
|
||||
out_transform(n, k, y, x, 2, 2);
|
||||
out_hold(n, k, y, x, 0, 1) =
|
||||
out_transform(n, k, y, x, 0, 1) - out_transform(n, k, y, x, 0, 2) -
|
||||
out_transform(n, k, y, x, 0, 3) + out_transform(n, k, y, x, 1, 1) -
|
||||
out_transform(n, k, y, x, 1, 2) - out_transform(n, k, y, x, 1, 3) +
|
||||
out_transform(n, k, y, x, 2, 1) - out_transform(n, k, y, x, 2, 2) -
|
||||
out_transform(n, k, y, x, 2, 3);
|
||||
out_hold(n, k, y, x, 1, 0) =
|
||||
out_transform(n, k, y, x, 1, 0) + out_transform(n, k, y, x, 1, 1) +
|
||||
out_transform(n, k, y, x, 1, 2) - out_transform(n, k, y, x, 2, 0) -
|
||||
out_transform(n, k, y, x, 2, 1) - out_transform(n, k, y, x, 2, 2) -
|
||||
out_transform(n, k, y, x, 3, 0) - out_transform(n, k, y, x, 3, 1) -
|
||||
out_transform(n, k, y, x, 3, 2);
|
||||
out_hold(n, k, y, x, 1, 1) =
|
||||
out_transform(n, k, y, x, 1, 1) - out_transform(n, k, y, x, 1, 2) -
|
||||
out_transform(n, k, y, x, 1, 3) - out_transform(n, k, y, x, 2, 1) +
|
||||
out_transform(n, k, y, x, 2, 2) + out_transform(n, k, y, x, 2, 3) -
|
||||
out_transform(n, k, y, x, 3, 1) + out_transform(n, k, y, x, 3, 2) +
|
||||
out_transform(n, k, y, x, 3, 3);
|
||||
auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
out_hold(n, k, htile, wtile, 0, 0) =
|
||||
out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) +
|
||||
out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) +
|
||||
out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) +
|
||||
out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) +
|
||||
out_transform(n, k, htile, wtile, 2, 2);
|
||||
out_hold(n, k, htile, wtile, 0, 1) =
|
||||
out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) -
|
||||
out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) -
|
||||
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) +
|
||||
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
|
||||
out_transform(n, k, htile, wtile, 2, 3);
|
||||
out_hold(n, k, htile, wtile, 1, 0) =
|
||||
out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) +
|
||||
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) -
|
||||
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
|
||||
out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) -
|
||||
out_transform(n, k, htile, wtile, 3, 2);
|
||||
out_hold(n, k, htile, wtile, 1, 1) =
|
||||
out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) -
|
||||
out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) +
|
||||
out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) -
|
||||
out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) +
|
||||
out_transform(n, k, htile, wtile, 3, 3);
|
||||
};
|
||||
|
||||
auto f_out = [&](auto n, auto k, auto y, auto x) {
|
||||
for(int j = 0; j < OutTileSizeH; ++j)
|
||||
auto f_out = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HoPerTile; ++j)
|
||||
{
|
||||
std::size_t ho = OutTileSizeH * y + j;
|
||||
for(int i = 0; i < OutTileSizeW; ++i)
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = OutTileSizeW * x + i;
|
||||
out(n, k, ho, wo) = out_hold(n, k, y, x, j, i);
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
make_ParallelTensorFunctor(f_in_hold, N, C, Y, X)(num_thread);
|
||||
make_ParallelTensorFunctor(f_in_transform, N, C, Y, X)(num_thread);
|
||||
make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_transform, N, K, Y, X)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_hold, N, K, Y, X)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out, N, K, Y, X)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
@@ -387,8 +399,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 1;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
@@ -399,8 +411,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 34;
|
||||
constexpr unsigned WI = 34;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
@@ -411,8 +423,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 56;
|
||||
constexpr unsigned WI = 56;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
#elif 0
|
||||
// 3x3, 58x58
|
||||
constexpr unsigned N = 64;
|
||||
@@ -420,8 +432,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 58;
|
||||
constexpr unsigned WI = 58;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
#elif 0
|
||||
// 5x5, 36x36
|
||||
constexpr unsigned N = 64;
|
||||
@@ -429,8 +441,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 36;
|
||||
constexpr unsigned WI = 36;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 5;
|
||||
constexpr unsigned R = 5;
|
||||
constexpr unsigned Y = 5;
|
||||
constexpr unsigned X = 5;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
@@ -441,8 +453,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 38;
|
||||
constexpr unsigned WI = 38;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 7;
|
||||
constexpr unsigned R = 7;
|
||||
constexpr unsigned Y = 7;
|
||||
constexpr unsigned X = 7;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
@@ -453,8 +465,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 58;
|
||||
constexpr unsigned WI = 58;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
#elif 0
|
||||
// 3x3 filter, 58x58 image, 0x0 padding
|
||||
constexpr unsigned N = 16;
|
||||
@@ -462,8 +474,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 58;
|
||||
constexpr unsigned WI = 58;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
@@ -474,8 +486,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 56;
|
||||
constexpr unsigned WI = 56;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
@@ -486,8 +498,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
@@ -498,8 +510,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned S = 1;
|
||||
constexpr unsigned R = 1;
|
||||
constexpr unsigned Y = 1;
|
||||
constexpr unsigned X = 1;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
@@ -510,8 +522,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 20;
|
||||
constexpr unsigned WI = 84;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
@@ -522,8 +534,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 112;
|
||||
constexpr unsigned WI = 112;
|
||||
constexpr unsigned K = 128;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
@@ -534,8 +546,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 20;
|
||||
constexpr unsigned WI = 86;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned S = 5;
|
||||
constexpr unsigned R = 5;
|
||||
constexpr unsigned Y = 5;
|
||||
constexpr unsigned X = 5;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
@@ -546,8 +558,8 @@ int main(int argc, char* argv[])
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 32;
|
||||
constexpr unsigned S = 5;
|
||||
constexpr unsigned R = 5;
|
||||
constexpr unsigned Y = 5;
|
||||
constexpr unsigned X = 5;
|
||||
|
||||
constexpr unsigned HPad = 2;
|
||||
constexpr unsigned WPad = 2;
|
||||
@@ -557,7 +569,7 @@ int main(int argc, char* argv[])
|
||||
auto upper_pads = Sequence<HPad, WPad>{};
|
||||
|
||||
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
|
||||
auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
|
||||
auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, Y, X>{});
|
||||
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
|
||||
in_nchw_desc, wei_kcsr_desc, lower_pads, upper_pads);
|
||||
|
||||
@@ -600,14 +612,8 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_1
|
||||
#elif 0
|
||||
device_direct_convolution_2
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_kcsr_nkhw
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_srck_nkhw
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_1_chwn_csrk_khwn
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_2_chwn_csrk_khwn
|
||||
#endif
|
||||
@@ -627,7 +633,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
if(S == 3 && R == 3)
|
||||
if(Y == 3 && X == 3)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user