Clean up conv example, Instances, profiler and test (#324)

* convnd_fwd fp16 example

* update example

* update example

* update instance

* updating refernce conv

* update reference conv

* update conv fwd profiler

* update conv 1d and 3d instance

* update include path

* clean

* update profiler for conv bwd data and weight

* update conv bwd weight

* clean

* update conv example

* update profiler for conv bwd weight

* update ckprofiler for conv bwd data

* fix reference conv bwd data bug; update conv bwd data test

* update examples

* fix initialization issue

* update test for conv fwd

* clean

* clean

* remove test case too sensitive to error threshhold

* fix test

* clean

* fix build

* adding conv multiple d

* adding conv multiple D

* add matrix padder

* add gemm padding to convnd

* adding group conv

* update gemm multi-d

* refactor

* refactor

* refactor

* clean

* clean

* refactor

* refactor

* reorg

* add ds

* add bias

* clean

* add G

* adding group

* adding group

* adding group

* update Tensor

* clean

* update example

* update DeviceGemmMultipleD_Xdl_CShuffle

* update conv bwd-data and bwd-weight

* upate contraction example

* update gemm and batch gemm with e permute

* fix example build

* instance for grouped conv1d

* update example

* adding group conv instance

* update gemm bilinear instance

* update gemm+add+add+fastgelu instance

* update profiler

* update profiler

* update test

* update test and client example

* clean

* add grouped conv into profiler

* update profiler

* clean

* add test grouped conv, update all conv test to gtest

* update test

[ROCm/composable_kernel commit: 500fa99512]
This commit is contained in:
Chao Liu
2022-07-29 18:19:25 -05:00
committed by GitHub
parent 1450273dc5
commit 236f946292
373 changed files with 17544 additions and 17013 deletions

View File

@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -8,22 +8,24 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
template <typename InDataType,
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator
{
// Argument
@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
float Run(const Argument& arg)
{
if constexpr(NumDimSpatial == 1)
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
throw std::runtime_error("wrong! inconsistent dimension");
}
AccDataType v_acc = 0;
if constexpr(NDimSpatial == 1)
{
auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t X = arg.weight_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[3];
float v_acc = 0;
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0)
{
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
float v_out = 0;
float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.wei_element_op_(
v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_out * v_wei;
}
@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
arg.in_element_op_(v_acc, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2])(
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 2)
else if constexpr(NDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
std::size_t X = arg.weight_.mDesc.GetLengths()[3];
auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Y = arg.weight_.GetLengths()[3];
std::size_t X = arg.weight_.GetLengths()[4];
std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[4];
AccDataType v_acc = 0;
float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0)
{
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[1]);
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0)
{
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
auto wo =
static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
float v_out = 0;
float v_wei = 0;
arg.out_element_op_(v_out,
ck::type_convert<AccDataType>(
arg.output_(n, k, ho, wo)));
arg.wei_element_op_(v_wei,
ck::type_convert<AccDataType>(
arg.weight_(k, c, y, x)));
arg.out_element_op_(
v_out,
ck::type_convert<float>(
arg.output_(g, n, k, ho, wo)));
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(
arg.weight_(g, k, c, y, x)));
v_acc += v_out * v_wei;
}
@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
AccDataType v_in;
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
};
make_ParallelTensorFunctor(f_nchw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3])(
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
else if constexpr(NDimSpatial == 3)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
std::size_t X = arg.weight_.mDesc.GetLengths()[4];
auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Z = arg.weight_.GetLengths()[3];
std::size_t Y = arg.weight_.GetLengths()[4];
std::size_t X = arg.weight_.GetLengths()[5];
std::size_t Do = arg.output_.mDesc.GetLengths()[2];
std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
std::size_t Do = arg.output_.GetLengths()[3];
std::size_t Ho = arg.output_.GetLengths()[4];
std::size_t Wo = arg.output_.GetLengths()[5];
AccDataType v_acc = 0;
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0)
{
auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp =
ck::type_convert<ck::long_index_t>(hi) +
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[1]);
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0)
{
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[1]);
auto ho =
static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
ck::type_convert<ck::long_index_t>(wi) +
ck::type_convert<ck::long_index_t>(
arg.in_left_pads_[2]) -
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[2]);
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0)
{
auto wo =
ck::type_convert<ck::long_index_t>(w_tmp) /
ck::type_convert<ck::long_index_t>(
arg.conv_strides_[2]);
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.conv_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_out = 0;
AccDataType v_wei = 0;
float v_out = 0;
float v_wei = 0;
arg.out_element_op_(
v_out,
ck::type_convert<AccDataType>(
arg.output_(
n, k, do_, ho, wo)));
ck::type_convert<float>(arg.output_(
g, n, k, do_, ho, wo)));
arg.wei_element_op_(
v_wei,
ck::type_convert<AccDataType>(
arg.weight_(k, c, z, y, x)));
ck::type_convert<float>(
arg.weight_(g, k, c, z, y, x)));
v_acc += v_out * v_wei;
}
@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
AccDataType v_in;
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncdhw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2],
arg.input_.mDesc.GetLengths()[3],
arg.input_.mDesc.GetLengths()[4])(
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4],
arg.input_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;

View File

@@ -7,21 +7,25 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
template <typename InDataType,
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator
{
// Argument
@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg)
{
if constexpr(NumDimSpatial == 1)
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{
constexpr auto I0 = Number<0>{};
auto f_kcx = [&](auto k, auto c, auto x) {
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto f_kcx = [&](auto g, auto k, auto c, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(arg.output_(n, k, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi)));
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in;
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2])(
arg.weight_.GetLengths()[0],
arg.weight_.GetLengths()[1],
arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 2)
else if constexpr(NDimSpatial == 2)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
v_out,
ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])(
arg.weight_.GetLengths()[0],
arg.weight_.GetLengths()[1],
arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
else if constexpr(NDimSpatial == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
{
auto di =
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[I2]) +
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
arg.input_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
arg.input_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
arg.input_.GetLengths()[5])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo)));
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
arg.output_(g, n, k, do_, ho, wo)));
arg.in_element_op_(v_in,
ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in;
}
@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])(
arg.weight_.GetLengths()[0],
arg.weight_.GetLengths()[1],
arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4],
arg.weight_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;

View File

@@ -8,7 +8,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
@@ -17,9 +17,10 @@ namespace host {
//
// @brief Reference implementation for forward convolution.
//
// @paragraph Supports both NCHW as well as NHWC formats (and their respective
// counterparts for weight and output) as long as tensor descriptor
// lengths is in NCHW.
// @paragraph
// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order
//
// @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type.
@@ -28,16 +29,20 @@ namespace host {
// operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation.
// @tparam NumDimSpatial Number of spatial dimensions.
// @tparam NDimSpatial Number of spatial dimensions.
//
template <typename InDataType,
// input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order
// output descriptor in [G, N, K, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator
{
// Argument
@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
float Run(const Argument& arg)
{
if constexpr(NumDimSpatial == 1)
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{
auto f_ncw = [&](auto n, auto k, auto wo) {
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
float v_in;
float v_wei;
arg.in_element_op_(v_in,
ck::type_convert<float>(arg.input_(n, c, wi)));
arg.wei_element_op_(v_wei,
ck::type_convert<float>(arg.weight_(k, c, x)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_in * v_wei;
}
@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_ncw,
arg.output_.mDesc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 2)
else if constexpr(NDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
float v_in;
float v_wei;
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
v_acc += v_in * v_wei;
}
}
@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_nchw,
arg.output_.mDesc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
else if constexpr(NDimSpatial == 3)
{
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
{
auto di =
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[2]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
arg.input_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
arg.input_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
arg.input_.GetLengths()[5])
{
float v_in;
float v_wei;
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
arg.in_element_op_(v_in,
ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi)));
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(arg.weight_(k, c, z, y, x)));
ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
v_acc += v_in * v_wei;
}
}
@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
float v_out;
arg.out_element_op_(v_out, v_acc);
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_nchw,
arg.output_.mDesc.GetLengths()[0],
arg.output_.mDesc.GetLengths()[1],
arg.output_.mDesc.GetLengths()[2],
arg.output_.mDesc.GetLengths()[3],
arg.output_.mDesc.GetLengths()[4])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4],
arg.output_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
bool IsSupportedArgument(const device::BaseArgument*) override
{
return NDimSpatial >= 1 && NDimSpatial <= 3;
}
static auto MakeArgument(const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,

View File

@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -9,8 +9,8 @@
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -9,8 +9,8 @@
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {

View File

@@ -10,22 +10,67 @@ namespace tensor_operation {
namespace device {
namespace instance {
// aliasing, for commonly used type
// aliasing, for commonly used data type
using F64 = double;
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using EMPTY_TUPLE = ck::Tuple<>;
using Empty_Tuple = ck::Tuple<>;
using F16_TUPLE = ck::Tuple<F16>;
using F16_F16_TUPLE = ck::Tuple<F16, F16>;
using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_TUPLE = ck::Tuple<F32>;
using F32_Tuple = ck::Tuple<F32>;
// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Tuple = ck::Tuple<Row>;
using Row_Row_Tuple = ck::Tuple<Row, Row>;
// Conv layout
//
using NWC = ck::tensor_layout::convolution::NWC;
using NHWC = ck::tensor_layout::convolution::NHWC;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
using KXC = ck::tensor_layout::convolution::KXC;
using KYXC = ck::tensor_layout::convolution::KYXC;
using KZYXC = ck::tensor_layout::convolution::KZYXC;
using NWK = ck::tensor_layout::convolution::NWK;
using NHWK = ck::tensor_layout::convolution::NHWK;
using NDHWK = ck::tensor_layout::convolution::NDHWK;
//
using GNWC = ck::tensor_layout::convolution::GNWC;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
using GKXC = ck::tensor_layout::convolution::GKXC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
//
using NWGC = ck::tensor_layout::convolution::NWGC;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
using KXGC = ck::tensor_layout::convolution::KXGC;
using KYXGC = ck::tensor_layout::convolution::KYXGC;
using KZYXGC = ck::tensor_layout::convolution::KZYXGC;
using NWGK = ck::tensor_layout::convolution::NWGK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;

View File

@@ -25,7 +25,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
2,
F32,
F32,
F32_TUPLE,
F32_Tuple,
F32,
PassThrough,
PassThrough,
@@ -37,7 +37,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
2,
F32,
F32,
F32_TUPLE,
F32_Tuple,
F32,
PassThrough,
PassThrough,
@@ -49,7 +49,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn
2,
F32,
F32,
F32_TUPLE,
F32_Tuple,
F32,
PassThrough,
PassThrough,
@@ -61,7 +61,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
2,
F32,
F32,
F32_TUPLE,
F32_Tuple,
F32,
PassThrough,
PassThrough,

View File

@@ -25,7 +25,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc
2,
F32,
F32,
EMPTY_TUPLE,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
@@ -37,7 +37,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc
2,
F32,
F32,
EMPTY_TUPLE,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
@@ -49,7 +49,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc
2,
F32,
F32,
EMPTY_TUPLE,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
@@ -61,7 +61,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
2,
F32,
F32,
EMPTY_TUPLE,
Empty_Tuple,
F32,
PassThrough,
PassThrough,

View File

@@ -0,0 +1,270 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward data
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
NWC,
KXC,
NWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
NWC,
KXC,
NWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
KYXC,
NHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
KZYXC,
NDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdData<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvBwdData<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,230 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
NWC,
KXC,
NWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
NHWC,
KYXC,
NHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// conv2d forward
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
KYXC,
NHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
KYXC,
NHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -19,49 +19,53 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_TUPLE,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_TUPLE,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_TUPLE,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_TUPLE,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
@@ -70,7 +74,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
// GEMM + Add + Add + FastGelu
template <typename ALayout,
typename BLayout,
typename DELayout,
typename D0Layout,
typename D1Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
@@ -79,7 +85,8 @@ template <typename ALayout,
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
DELayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
@@ -90,7 +97,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
DELayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
@@ -108,27 +116,31 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<DELayout, Row>)
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<DELayout, Row>)
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<DELayout, Row>)
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<DELayout, Row>)
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
}
}

View File

@@ -19,49 +19,53 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_TUPLE,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_TUPLE,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_TUPLE,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_TUPLE,
F16_Tuple,
F16,
PassThrough,
PassThrough,
@@ -70,7 +74,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
// GEMM + Bilinear
template <typename ALayout,
typename BLayout,
typename DELayout,
typename DLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DDataType,
@@ -78,7 +83,8 @@ template <typename ALayout,
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
DELayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
@@ -89,7 +95,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
DELayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
@@ -106,24 +113,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<DELayout, Row>)
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<DELayout, Row>)
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<DELayout, Row>)
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<DELayout, Row>)
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
op_ptrs);
}
}

View File

@@ -0,0 +1,352 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv1d forward, GNWC/GKXC/GNWK
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv2d forward, NHWGC/KYXGC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
KYXGC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
Empty_Tuple,
OutLayout,
InDataType,
WeiDataType,
Empty_Tuple,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
Empty_Tuple,
OutLayout,
InDataType,
WeiDataType,
Empty_Tuple,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, KYXGC> && is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
// no instance
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
// no instance
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
// no instance
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -16,15 +16,14 @@ namespace tensor_operation {
namespace device {
namespace instance {
using DsType = Tuple<>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
DsType,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
@@ -33,10 +32,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
DsType,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
@@ -45,10 +45,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
DsType,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
@@ -57,10 +58,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
DsType,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
@@ -68,18 +70,18 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
ALayout,
BLayout,
CLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
DsDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
@@ -87,10 +89,11 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
using DeviceOp = DeviceGroupedGemm<ALayout,
BLayout,
CLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
DsDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
@@ -104,22 +107,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}

View File

@@ -13,7 +13,9 @@
#include <type_traits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace utils {
@@ -194,10 +196,3 @@ check_err(const std::vector<T>& out,
} // namespace utils
} // namespace ck
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}

View File

@@ -1,574 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/op_instance_engine.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
element_wise::PassThrough,
element_wise::PassThrough>;
namespace instance {
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace instance
namespace instance {
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace instance
namespace instance {
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace utils {
namespace conv {
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::device::DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
/**
* @brief Calculate number of FLOPs for Convolution
*
* @param[in] N Batch size.
* @param[in] C Number of input channels.
* @param[in] K Number of output channels.
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
* @param[in] output_spatial_lengths Convolution output spatial dimensions
* lengths.
*
* @return The number of flops.
*/
std::size_t get_flops(ck::index_t N,
ck::index_t C,
ck::index_t K,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths);
/**
* @brief Calculate number of bytes read/write by convolution algorithm.
*
* @param[in] N Batch size.
* @param[in] C Number of input channels.
* @param[in] K Number of output channels.
* @param[in] input_spatial_lengths Input spatial dimensions lengths.
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
* @param[in] output_spatial_lengths Output spatial dimensions lengths
*
* @tparam InDataType Input tensor data type.
* @tparam WeiDataType Weights tensor data type.
* @tparam OutDataType Output tensor data type.
*
* @return The number of used bytes.
*/
template <typename InDataType = float,
typename WeiDataType = InDataType,
typename OutDataType = InDataType>
std::size_t get_btype(ck::index_t N,
ck::index_t C,
ck::index_t K,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths)
{
// sizeof(InDataType) * (N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (N * K * <output spatial lengths product>);
return sizeof(InDataType) * (N * C *
std::accumulate(std::begin(input_spatial_lengths),
std::end(input_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(WeiDataType) * (K * C *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(OutDataType) * (N * K *
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
struct ConvParams
{
ConvParams();
ConvParams(ck::index_t n_dim,
ck::index_t n_batch,
ck::index_t n_out_channels,
ck::index_t n_in_channels,
const std::vector<ck::index_t>& filters_len,
const std::vector<ck::index_t>& input_len,
const std::vector<ck::index_t>& strides,
const std::vector<ck::index_t>& dilations,
const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial_;
ck::index_t N_;
ck::index_t K_;
ck::index_t C_;
std::vector<ck::index_t> filter_spatial_lengths_;
std::vector<ck::index_t> input_spatial_lengths_;
std::vector<ck::index_t> conv_filter_strides_;
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
std::vector<ck::index_t> GetOutputSpatialLengths() const;
};
ConvParams parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]);
/**
* @brief Gets the host tensor descriptor.
*
* @param[in] dims The tensor dimensions lengths. Always in NCHW format.
* @param[in] layout The tensor data layout.
*
* @tparam TensorLayout Layout type.
*
* @return The host tensor descriptor object.
*/
template <typename TensorLayout>
HostTensorDescriptor get_host_tensor_descriptor(const std::vector<std::size_t>& dims,
const TensorLayout& layout)
{
std::size_t C = dims[1];
// 1D
if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKW>::value)
{
return HostTensorDescriptor(dims, std::vector<std::size_t>{C * dims[2], dims[2], 1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NWK>::value)
{
return HostTensorDescriptor(dims, std::vector<std::size_t>{C * dims[2], 1, C});
}
// 2D
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCHW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCYX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKHW>::value)
{
return HostTensorDescriptor(
dims, std::vector<std::size_t>{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KYXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWK>::value)
{
return HostTensorDescriptor(
dims, std::vector<std::size_t>{C * dims[2] * dims[3], 1, dims[3] * C, C});
}
// 3D
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCZYX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKDHW>::value)
{
return HostTensorDescriptor(dims,
std::vector<std::size_t>{C * dims[2] * dims[3] * dims[4],
dims[2] * dims[3] * dims[4],
dims[3] * dims[4],
dims[4],
1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KZYXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWK>::value)
{
return HostTensorDescriptor(
dims,
std::vector<std::size_t>{
C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C});
}
std::stringstream err_msg;
err_msg << "Unsupported data layout provided: " << layout << "!";
throw std::runtime_error(err_msg.str());
}
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2);
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2);
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2);
template <ck::index_t NDim,
typename InDataType = float,
typename WeiDataType = float,
typename OutDataType = float>
void run_reference_convolution_forward(const ConvParams& params,
const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weights,
Tensor<OutDataType>& output)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
OutDataType,
PassThrough,
PassThrough,
PassThrough,
NDim>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input,
weights,
output,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
}
template <typename InDataType, typename WeiDataType, typename OutDataType>
struct ConvolutionFwdInstances;
template <>
struct ConvolutionFwdInstances<float, float, float>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <>
struct ConvolutionFwdInstances<half_t, half_t, half_t>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
return conv_ptrs;
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <>
struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <>
struct ConvolutionFwdInstances<int8_t, int8_t, int8_t>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout = ck::tensor_layout::convolution::NHWC,
typename WeiLayout = ck::tensor_layout::convolution::KYXC,
typename OutLayout = ck::tensor_layout::convolution::NHWK,
typename InElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename WeiElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename OutElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename InputInitFun = FillUniformDistribution<InDataType>,
typename WeightsInitFun = FillUniformDistribution<WeiDataType>>
class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>
{
using DeviceConvFwdOp = tensor_operation::device::
DeviceConvFwd<InElementwiseOp, WeiElementwiseOp, OutElementwiseOp>;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
using DeviceBuffers = std::vector<DeviceMemPtr>;
using BaseType = ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>;
template <typename T>
using TensorPtr = std::unique_ptr<Tensor<T>>;
using InTensorsTuple = std::tuple<TensorPtr<InDataType>, TensorPtr<WeiDataType>>;
public:
ConvFwdOpInstance() = delete;
ConvFwdOpInstance(const ConvFwdOpInstance&) = default;
ConvFwdOpInstance& operator=(const ConvFwdOpInstance&) = default;
ConvFwdOpInstance(const ConvParams& params,
bool do_init = true,
const InputInitFun& input_init_f = InputInitFun(),
const WeightsInitFun& weights_init_f = WeightsInitFun())
: BaseType(),
params_{params},
output_spatial_lengths_{params.GetOutputSpatialLengths()},
do_init_{do_init},
input_init_f_{input_init_f},
weights_init_f_{weights_init_f}
{
}
virtual ~ConvFwdOpInstance() override{};
virtual InTensorsTuple GetInputTensors() const override
{
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params_.N_),
static_cast<std::size_t>(params_.C_)};
input_dims.insert(std::end(input_dims),
std::begin(params_.input_spatial_lengths_),
std::end(params_.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params_.K_),
static_cast<std::size_t>(params_.C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(params_.filter_spatial_lengths_),
std::end(params_.filter_spatial_lengths_));
auto input = std::make_unique<Tensor<InDataType>>(
get_host_tensor_descriptor(input_dims, InLayout{}));
auto weights = std::make_unique<Tensor<WeiDataType>>(
get_host_tensor_descriptor(filter_dims, WeiLayout{}));
if(do_init_)
{
input_init_f_(input->begin(), input->end());
weights_init_f_(weights->begin(), weights->end());
}
return std::make_tuple(std::move(input), std::move(weights));
}
virtual TensorPtr<OutDataType> GetOutputTensor() const override
{
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params_.N_),
static_cast<std::size_t>(params_.K_)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_));
auto output = std::make_unique<Tensor<OutDataType>>(
get_host_tensor_descriptor(output_dims, OutLayout{}));
if(do_init_)
{
std::fill(output->begin(), output->end(), OutDataType(0.f));
}
return output;
}
virtual std::unique_ptr<tensor_operation::device::BaseInvoker>
MakeInvokerPointer(tensor_operation::device::BaseOperator* op_ptr) const override
{
static_assert(
std::is_same_v<InElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<OutElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<WeiElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
auto conv_ptr = dynamic_cast<DeviceConvFwdOp*>(op_ptr);
if(!conv_ptr)
{
throw std::runtime_error(
"[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!");
}
return conv_ptr->MakeInvokerPointer();
}
virtual std::unique_ptr<tensor_operation::device::BaseArgument>
MakeArgumentPointer(tensor_operation::device::BaseOperator* op_ptr,
const DeviceBuffers& in_device_buffers,
const DeviceMemPtr& out_device_buffer) const override
{
static_assert(
std::is_same_v<InElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<OutElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<WeiElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
auto conv_ptr = dynamic_cast<DeviceConvFwdOp*>(op_ptr);
if(!conv_ptr)
{
throw std::runtime_error(
"[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!");
}
return conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buffers[0]->GetDeviceBuffer()),
static_cast<WeiDataType*>(in_device_buffers[1]->GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buffer->GetDeviceBuffer()),
params_.N_,
params_.K_,
params_.C_,
params_.input_spatial_lengths_,
params_.filter_spatial_lengths_,
output_spatial_lengths_,
params_.conv_filter_strides_,
params_.conv_filter_dilations_,
params_.input_left_pads_,
params_.input_right_pads_,
InElementwiseOp{},
WeiElementwiseOp{},
OutElementwiseOp{});
}
virtual std::size_t GetFlops() const override
{
return get_flops(params_.N_,
params_.C_,
params_.K_,
params_.filter_spatial_lengths_,
output_spatial_lengths_);
}
virtual std::size_t GetBtype() const override
{
return get_btype<InDataType, WeiDataType, OutDataType>(params_.N_,
params_.C_,
params_.K_,
params_.input_spatial_lengths_,
params_.filter_spatial_lengths_,
output_spatial_lengths_);
}
private:
const ConvParams& params_;
const std::vector<ck::index_t> output_spatial_lengths_;
const bool do_init_;
InputInitFun input_init_f_;
WeightsInitFun weights_init_f_;
};
} // namespace conv
} // namespace utils
} // namespace ck
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParams& p);

View File

@@ -0,0 +1,354 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
namespace ck {
namespace utils {
namespace conv {
namespace detail {
template <typename OldLayout>
std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KZYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
// separate from legacy code above
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKW>)
{
return {0, 1, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCDHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCZYX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKDHW>)
{
return {0, 1, 2, 3, 4, 5};
}
if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNDHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKZYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KXGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWGK>)
{
return {2, 0, 3, 1};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KYXGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWGK>)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KZYXGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWGK>)
{
return {4, 0, 5, 1, 2, 3};
}
else
{
printf("%s\n", __func__);
throw std::runtime_error("wrong! unsupported layout");
}
}
} // namespace detail
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template <typename InLayout>
HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNDHWC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWGC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWGC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", InLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<InLayout>());
}
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template <typename WeiLayout>
HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck::utils::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCZYX>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKZYXC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXGC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXGC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", WeiLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
}
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template <typename OutLayout>
HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWGK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWGK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWGK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", OutLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
}
} // namespace conv
} // namespace utils
} // namespace ck

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
#include "ck/ck.hpp"
namespace ck {
namespace utils {
namespace conv {
struct ConvParam
{
ConvParam();
ConvParam(ck::index_t n_dim,
ck::index_t group_count,
ck::index_t n_batch,
ck::index_t n_out_channels,
ck::index_t n_in_channels,
const std::vector<ck::index_t>& filters_len,
const std::vector<ck::index_t>& input_len,
const std::vector<ck::index_t>& strides,
const std::vector<ck::index_t>& dilations,
const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial_;
ck::index_t G_;
ck::index_t N_;
ck::index_t K_;
ck::index_t C_;
std::vector<ck::index_t> filter_spatial_lengths_;
std::vector<ck::index_t> input_spatial_lengths_;
std::vector<ck::index_t> output_spatial_lengths_;
std::vector<ck::index_t> conv_filter_strides_;
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
std::vector<ck::index_t> GetOutputSpatialLengths() const;
std::size_t GetFlops() const;
template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(InDataType) *
(G_ * N_ * C_ *
std::accumulate(std::begin(input_spatial_lengths_),
std::begin(input_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(WeiDataType) *
(G_ * K_ * C_ *
std::accumulate(std::begin(filter_spatial_lengths_),
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
};
std::string get_conv_param_parser_helper_msg();
ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]);
} // namespace conv
} // namespace utils
} // namespace ck
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);

View File

@@ -11,8 +11,8 @@
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/host_tensor/host_common_util.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,

View File

@@ -73,22 +73,41 @@ auto construct_f_unpack_args(F, T args)
struct HostTensorDescriptor
{
HostTensorDescriptor() = delete;
template <typename X>
HostTensorDescriptor(const std::vector<X>& lens);
template <typename X, typename Y>
HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides);
HostTensorDescriptor() = default;
void CalculateStrides();
template <typename X>
HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
template <typename X>
HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
template <typename Range>
HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
template <typename X, typename Y>
HostTensorDescriptor(const std::initializer_list<X>& lens,
const std::initializer_list<Y>& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}
template <typename X, typename Y>
HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}
template <typename Range1, typename Range2>
HostTensorDescriptor(const Range1& lens, const Range2& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
@@ -97,7 +116,7 @@ struct HostTensorDescriptor
std::size_t GetNumOfDimension() const;
std::size_t GetElementSize() const;
std::size_t GetElementSpace() const;
std::size_t GetElementSpaceSize() const;
const std::vector<std::size_t>& GetLengths() const;
const std::vector<std::size_t>& GetStrides() const;
@@ -122,6 +141,22 @@ struct HostTensorDescriptor
std::vector<std::size_t> mStrides;
};
template <typename New2Old>
HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
const New2Old& new2old)
{
std::vector<std::size_t> new_lengths(a.GetNumOfDimension());
std::vector<std::size_t> new_strides(a.GetNumOfDimension());
for(std::size_t i = 0; i < a.GetNumOfDimension(); i++)
{
new_lengths[i] = a.GetLengths()[new2old[i]];
new_strides[i] = a.GetStrides()[new2old[i]];
}
return HostTensorDescriptor(new_lengths, new_strides);
}
struct joinable_thread : std::thread
{
template <typename... Xs>
@@ -203,22 +238,22 @@ template <typename T>
struct Tensor
{
template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
{
}
template <typename X>
Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
{
}
template <typename X, typename Y>
Tensor(std::vector<X> lens, std::vector<Y> strides)
: mDesc(lens, strides), mData(mDesc.GetElementSpace())
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
{
}
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType()
@@ -240,6 +275,24 @@ struct Tensor
return *this;
}
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
void SetZero()
{
for(auto& v : mData)
{
v = T{0};
}
}
template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
{
@@ -330,6 +383,19 @@ struct Tensor
mDesc.GetLengths()[4])(num_thread);
break;
}
case 6: {
auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
(*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4, i5);
};
make_ParallelTensorFunctor(f,
mDesc.GetLengths()[0],
mDesc.GetLengths()[1],
mDesc.GetLengths()[2],
mDesc.GetLengths()[3],
mDesc.GetLengths()[4],
mDesc.GetLengths()[5])(num_thread);
break;
}
default: throw std::runtime_error("unspported dimension");
}
}
@@ -367,17 +433,3 @@ struct Tensor
HostTensorDescriptor mDesc;
std::vector<T> mData;
};
template <typename X>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
: mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
template <typename X, typename Y>
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
const std::vector<Y>& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}

View File

@@ -16,8 +16,8 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace utils {
@@ -103,8 +103,8 @@ class OpInstanceRunEngine
}
}
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
out_device_buffer_ =
std::make_unique<DeviceMem>(sizeof(OutDataType) * out_tensor_->mDesc.GetElementSpace());
out_device_buffer_ = std::make_unique<DeviceMem>(sizeof(OutDataType) *
out_tensor_->mDesc.GetElementSpaceSize());
out_device_buffer_->SetZero();
}
@@ -222,7 +222,7 @@ class OpInstanceRunEngine
in_device_buffers_
.emplace_back(
std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) *
ts->mDesc.GetElementSpace()))
ts->mDesc.GetElementSpaceSize()))
->ToDevice(ts->mData.data());
}