mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
@@ -40,11 +40,11 @@ struct GeneratorTensor_Checkboard
|
||||
// Checkerboard value generator: maps the coordinates Xs... to +1 or -1.
// The fold starts from `true` and flips the flag for every coordinate with
// x % 2 == 1, so the result is 1 when an even number of coordinates are odd
// and -1 otherwise.
// NOTE(review): this span is a diff rendering, so the pre-change
// (`unsigned long`) and post-change (`index_t`) variants of the `dims`
// declaration and of the lambda both appear below.
// NOTE(review): index_t is declared outside this view; if it is a signed type,
// a negative odd coordinate yields x % 2 == -1, which never equals the flag,
// so the flag is forced to true instead of flipped. Confirm coordinates are
// always non-negative.
template <class... Ts>
|
||||
double operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
|
||||
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
|
||||
// The accumulator type is deduced from the initial value `true`, so the int
// returned by the lambda is converted back to bool on every step.
return std::accumulate(dims.begin(),
|
||||
dims.end(),
|
||||
true,
|
||||
[](bool init, unsigned long x) -> int { return init != (x % 2); })
|
||||
[](bool init, index_t x) -> int { return init != (x % 2); })
|
||||
? 1
|
||||
: -1;
|
||||
}
|
||||
@@ -80,9 +80,9 @@ auto make_TensorDescriptor(TConstTensorDesc)
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto desc = TConstTensorDesc{};
|
||||
|
||||
std::initializer_list<unsigned> lengths = {
|
||||
std::initializer_list<index_t> lengths = {
|
||||
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
|
||||
std::initializer_list<unsigned> strides = {
|
||||
std::initializer_list<index_t> strides = {
|
||||
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
|
||||
|
||||
return TensorDescriptor(lengths, strides);
|
||||
@@ -95,11 +95,11 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
|
||||
LowerPads,
|
||||
UpperPads)
|
||||
{
|
||||
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
|
||||
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
|
||||
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
|
||||
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
|
||||
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
|
||||
|
||||
auto f = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
@@ -153,11 +153,11 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
|
||||
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
|
||||
|
||||
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
|
||||
unsigned h_pad_up = UpperPads{}.Get(Number<0>{});
|
||||
unsigned w_pad_up = UpperPads{}.Get(Number<1>{});
|
||||
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
|
||||
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
|
||||
|
||||
std::size_t HiPerTile = HoPerTile + Y - 1;
|
||||
std::size_t WiPerTile = WoPerTile + X - 1;
|
||||
@@ -399,211 +399,211 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if 0
|
||||
constexpr unsigned N = 1;
|
||||
constexpr unsigned C = 1;
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 1;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 1;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 1;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 34;
|
||||
constexpr unsigned WI = 34;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 64;
|
||||
constexpr unsigned HI = 56;
|
||||
constexpr unsigned WI = 56;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
#elif 0
|
||||
// 3x3, 58x58
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 64;
|
||||
constexpr unsigned HI = 58;
|
||||
constexpr unsigned WI = 58;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 58;
|
||||
constexpr index_t WI = 58;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
#elif 0
|
||||
// 5x5, 36x36
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 36;
|
||||
constexpr unsigned WI = 36;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned Y = 5;
|
||||
constexpr unsigned X = 5;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 36;
|
||||
constexpr index_t WI = 36;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 5;
|
||||
constexpr index_t X = 5;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 7x7, 38x38
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 38;
|
||||
constexpr unsigned WI = 38;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned Y = 7;
|
||||
constexpr unsigned X = 7;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 38;
|
||||
constexpr index_t WI = 38;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 58x58
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 128;
|
||||
constexpr unsigned HI = 58;
|
||||
constexpr unsigned WI = 58;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 58;
|
||||
constexpr index_t WI = 58;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
#elif 0
|
||||
// 3x3 filter, 58x58 image, 0x0 padding
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 128;
|
||||
constexpr unsigned HI = 58;
|
||||
constexpr unsigned WI = 58;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 58;
|
||||
constexpr index_t WI = 58;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3 filter, 56x56 image, 1x1 padding
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 128;
|
||||
constexpr unsigned HI = 56;
|
||||
constexpr unsigned WI = 56;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
constexpr index_t HPad = 1;
|
||||
constexpr index_t WPad = 1;
|
||||
#elif 0
|
||||
// 3x3 filter, 28x28 image, 1x1 padding
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
constexpr index_t HPad = 1;
|
||||
constexpr index_t WPad = 1;
|
||||
#elif 0
|
||||
// 1x1 filter, 28x28 image
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned Y = 1;
|
||||
constexpr unsigned X = 1;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3 filter, 20x84 image, 1x1 padding
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 20;
|
||||
constexpr unsigned WI = 84;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 20;
|
||||
constexpr index_t WI = 84;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
constexpr index_t HPad = 1;
|
||||
constexpr index_t WPad = 1;
|
||||
#elif 0
|
||||
// 3x3 filter, 112x112 image, 1x1 padding
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 64;
|
||||
constexpr unsigned HI = 112;
|
||||
constexpr unsigned WI = 112;
|
||||
constexpr unsigned K = 128;
|
||||
constexpr unsigned Y = 3;
|
||||
constexpr unsigned X = 3;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 112;
|
||||
constexpr index_t WI = 112;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
constexpr index_t HPad = 1;
|
||||
constexpr index_t WPad = 1;
|
||||
#elif 0
|
||||
// 5x5 filter, 20x86 image, 1x1 padding
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 20;
|
||||
constexpr unsigned WI = 86;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned Y = 5;
|
||||
constexpr unsigned X = 5;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 20;
|
||||
constexpr index_t WI = 86;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 5;
|
||||
constexpr index_t X = 5;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
constexpr index_t HPad = 1;
|
||||
constexpr index_t WPad = 1;
|
||||
#elif 0
|
||||
// 5x5 filter, 28x28 image, 2x2 padding
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 192;
|
||||
constexpr unsigned HI = 28;
|
||||
constexpr unsigned WI = 28;
|
||||
constexpr unsigned K = 32;
|
||||
constexpr unsigned Y = 5;
|
||||
constexpr unsigned X = 5;
|
||||
constexpr index_t N = 16;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t Y = 5;
|
||||
constexpr index_t X = 5;
|
||||
|
||||
constexpr unsigned HPad = 2;
|
||||
constexpr unsigned WPad = 2;
|
||||
constexpr index_t HPad = 2;
|
||||
constexpr index_t WPad = 2;
|
||||
#elif 0
|
||||
// 1x1 filter, 32x32 image
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 32;
|
||||
constexpr unsigned WI = 32;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned Y = 1;
|
||||
constexpr unsigned X = 1;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 32;
|
||||
constexpr index_t WI = 32;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 1x1 filter, 14x14 image
|
||||
constexpr unsigned N = 128;
|
||||
constexpr unsigned C = 2048;
|
||||
constexpr unsigned HI = 14;
|
||||
constexpr unsigned WI = 14;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned Y = 1;
|
||||
constexpr unsigned X = 1;
|
||||
// 1x1 filter, 14x14 image, C = 2048
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 2048;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
// 1x1 filter, 14x14 image, C = 512
|
||||
constexpr unsigned N = 128;
|
||||
constexpr unsigned C = 512;
|
||||
constexpr unsigned HI = 14;
|
||||
constexpr unsigned WI = 14;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned Y = 1;
|
||||
constexpr unsigned X = 1;
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#endif
|
||||
|
||||
auto lower_pads = Sequence<HPad, WPad>{};
|
||||
@@ -634,7 +634,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
bool do_verification = atoi(argv[1]);
|
||||
unsigned nrepeat = atoi(argv[2]);
|
||||
index_t nrepeat = atoi(argv[2]);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user