mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Clean up conv example, Instances, profiler and test (#324)
* convnd_fwd fp16 example
* update example
* update example
* update instance
* updating reference conv
* update reference conv
* update conv fwd profiler
* update conv 1d and 3d instance
* update include path
* clean
* update profiler for conv bwd data and weight
* update conv bwd weight
* clean
* update conv example
* update profiler for conv bwd weight
* update ckprofiler for conv bwd data
* fix reference conv bwd data bug; update conv bwd data test
* update examples
* fix initialization issue
* update test for conv fwd
* clean
* clean
* remove test case too sensitive to error threshold
* fix test
* clean
* fix build
* adding conv multiple d
* adding conv multiple D
* add matrix padder
* add gemm padding to convnd
* adding group conv
* update gemm multi-d
* refactor
* refactor
* refactor
* clean
* clean
* refactor
* refactor
* reorg
* add ds
* add bias
* clean
* add G
* adding group
* adding group
* adding group
* update Tensor
* clean
* update example
* update DeviceGemmMultipleD_Xdl_CShuffle
* update conv bwd-data and bwd-weight
* update contraction example
* update gemm and batch gemm with e permute
* fix example build
* instance for grouped conv1d
* update example
* adding group conv instance
* update gemm bilinear instance
* update gemm+add+add+fastgelu instance
* update profiler
* update profiler
* update test
* update test and client example
* clean
* add grouped conv into profiler
* update profiler
* clean
* add test grouped conv, update all conv test to gtest
* update test
[ROCm/composable_kernel commit: 500fa99512]
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -8,22 +8,24 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
|
||||
template <typename InDataType,
|
||||
// input descriptor in [G, N, C, Di, Hi, Wi] order
|
||||
// weight descriptor in [G, K, C, Z, Y, X] order
|
||||
// output descriptor in [G, N, K, Do, Ho, Wo] order
|
||||
// physical layout is irrelevant
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t NumDimSpatial = 2,
|
||||
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvBwdData : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
|
||||
{
|
||||
auto f_ncw = [&](auto n, auto c, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
|
||||
std::size_t K = arg.weight_.GetLengths()[1];
|
||||
std::size_t X = arg.weight_.GetLengths()[3];
|
||||
std::size_t Wo = arg.output_.GetLengths()[3];
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
|
||||
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
|
||||
auto w_tmp = static_cast<ck::long_index_t>(wi) +
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
|
||||
|
||||
if(w_tmp % arg.conv_strides_[0] == 0)
|
||||
{
|
||||
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
|
||||
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
|
||||
auto wo = static_cast<ck::long_index_t>(w_tmp) /
|
||||
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
|
||||
|
||||
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_out = 0;
|
||||
AccDataType v_wei = 0;
|
||||
float v_out = 0;
|
||||
float v_wei = 0;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
|
||||
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
|
||||
|
||||
arg.wei_element_op_(
|
||||
v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
|
||||
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
|
||||
|
||||
v_acc += v_out * v_wei;
|
||||
}
|
||||
@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
}
|
||||
}
|
||||
|
||||
arg.in_element_op_(v_acc, v_acc);
|
||||
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
|
||||
float v_in;
|
||||
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
|
||||
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ncw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2])(
|
||||
arg.input_.GetLengths()[0],
|
||||
arg.input_.GetLengths()[1],
|
||||
arg.input_.GetLengths()[2],
|
||||
arg.input_.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t X = arg.weight_.mDesc.GetLengths()[3];
|
||||
auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = arg.weight_.GetLengths()[1];
|
||||
std::size_t Y = arg.weight_.GetLengths()[3];
|
||||
std::size_t X = arg.weight_.GetLengths()[4];
|
||||
|
||||
std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
|
||||
std::size_t Ho = arg.output_.GetLengths()[3];
|
||||
std::size_t Wo = arg.output_.GetLengths()[4];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t y = 0; y < Y; ++y)
|
||||
{
|
||||
auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
|
||||
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
|
||||
auto h_tmp = static_cast<ck::long_index_t>(hi) +
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);
|
||||
if(h_tmp % arg.conv_strides_[0] == 0)
|
||||
{
|
||||
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
|
||||
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
|
||||
auto ho = static_cast<ck::long_index_t>(h_tmp) /
|
||||
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
|
||||
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
|
||||
{
|
||||
for(std::size_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto w_tmp =
|
||||
ck::type_convert<ck::long_index_t>(wi) +
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
|
||||
ck::type_convert<ck::long_index_t>(x *
|
||||
arg.conv_dilations_[1]);
|
||||
static_cast<ck::long_index_t>(wi) +
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);
|
||||
if(w_tmp % arg.conv_strides_[1] == 0)
|
||||
{
|
||||
auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
|
||||
ck::type_convert<ck::long_index_t>(
|
||||
arg.conv_strides_[1]);
|
||||
auto wo =
|
||||
static_cast<ck::long_index_t>(w_tmp) /
|
||||
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
|
||||
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_out = 0;
|
||||
AccDataType v_wei = 0;
|
||||
float v_out = 0;
|
||||
float v_wei = 0;
|
||||
|
||||
arg.out_element_op_(v_out,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.output_(n, k, ho, wo)));
|
||||
arg.wei_element_op_(v_wei,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.weight_(k, c, y, x)));
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
ck::type_convert<float>(
|
||||
arg.output_(g, n, k, ho, wo)));
|
||||
|
||||
arg.wei_element_op_(
|
||||
v_wei,
|
||||
ck::type_convert<float>(
|
||||
arg.weight_(g, k, c, y, x)));
|
||||
|
||||
v_acc += v_out * v_wei;
|
||||
}
|
||||
@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_in;
|
||||
float v_in;
|
||||
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
|
||||
|
||||
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2],
|
||||
arg.input_.mDesc.GetLengths()[3])(
|
||||
arg.input_.GetLengths()[0],
|
||||
arg.input_.GetLengths()[1],
|
||||
arg.input_.GetLengths()[2],
|
||||
arg.input_.GetLengths()[3],
|
||||
arg.input_.GetLengths()[4])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
|
||||
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
|
||||
std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
|
||||
std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
|
||||
std::size_t X = arg.weight_.mDesc.GetLengths()[4];
|
||||
auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
|
||||
std::size_t K = arg.weight_.GetLengths()[1];
|
||||
std::size_t Z = arg.weight_.GetLengths()[3];
|
||||
std::size_t Y = arg.weight_.GetLengths()[4];
|
||||
std::size_t X = arg.weight_.GetLengths()[5];
|
||||
|
||||
std::size_t Do = arg.output_.mDesc.GetLengths()[2];
|
||||
std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
|
||||
std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
|
||||
std::size_t Do = arg.output_.GetLengths()[3];
|
||||
std::size_t Ho = arg.output_.GetLengths()[4];
|
||||
std::size_t Wo = arg.output_.GetLengths()[5];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t z = 0; z < Z; ++z)
|
||||
{
|
||||
auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
|
||||
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
|
||||
auto d_tmp = static_cast<ck::long_index_t>(di) +
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
|
||||
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);
|
||||
if(d_tmp % arg.conv_strides_[0] == 0)
|
||||
{
|
||||
auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
|
||||
ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
|
||||
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
|
||||
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
|
||||
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
|
||||
{
|
||||
for(std::size_t y = 0; y < Y; ++y)
|
||||
{
|
||||
auto h_tmp =
|
||||
ck::type_convert<ck::long_index_t>(hi) +
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
|
||||
ck::type_convert<ck::long_index_t>(y *
|
||||
arg.conv_dilations_[1]);
|
||||
static_cast<ck::long_index_t>(hi) +
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);
|
||||
if(h_tmp % arg.conv_strides_[1] == 0)
|
||||
{
|
||||
auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
|
||||
ck::type_convert<ck::long_index_t>(
|
||||
arg.conv_strides_[1]);
|
||||
auto ho =
|
||||
static_cast<ck::long_index_t>(h_tmp) /
|
||||
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
|
||||
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
|
||||
{
|
||||
for(std::size_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto w_tmp =
|
||||
ck::type_convert<ck::long_index_t>(wi) +
|
||||
ck::type_convert<ck::long_index_t>(
|
||||
arg.in_left_pads_[2]) -
|
||||
ck::type_convert<ck::long_index_t>(
|
||||
x * arg.conv_dilations_[2]);
|
||||
auto w_tmp = static_cast<ck::long_index_t>(wi) +
|
||||
static_cast<ck::long_index_t>(
|
||||
arg.in_left_pads_[2]) -
|
||||
static_cast<ck::long_index_t>(
|
||||
x * arg.conv_dilations_[2]);
|
||||
|
||||
if(w_tmp % arg.conv_strides_[2] == 0)
|
||||
{
|
||||
auto wo =
|
||||
ck::type_convert<ck::long_index_t>(w_tmp) /
|
||||
ck::type_convert<ck::long_index_t>(
|
||||
arg.conv_strides_[2]);
|
||||
auto wo = static_cast<ck::long_index_t>(w_tmp) /
|
||||
static_cast<ck::long_index_t>(
|
||||
arg.conv_strides_[2]);
|
||||
if(wo >= 0 &&
|
||||
ck::type_convert<std::size_t>(wo) < Wo)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_out = 0;
|
||||
AccDataType v_wei = 0;
|
||||
float v_out = 0;
|
||||
float v_wei = 0;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.output_(
|
||||
n, k, do_, ho, wo)));
|
||||
ck::type_convert<float>(arg.output_(
|
||||
g, n, k, do_, ho, wo)));
|
||||
|
||||
arg.wei_element_op_(
|
||||
v_wei,
|
||||
ck::type_convert<AccDataType>(
|
||||
arg.weight_(k, c, z, y, x)));
|
||||
ck::type_convert<float>(
|
||||
arg.weight_(g, k, c, z, y, x)));
|
||||
|
||||
v_acc += v_out * v_wei;
|
||||
}
|
||||
@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_in;
|
||||
float v_in;
|
||||
|
||||
arg.in_element_op_(v_in, v_acc);
|
||||
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
|
||||
|
||||
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ncdhw,
|
||||
arg.input_.mDesc.GetLengths()[0],
|
||||
arg.input_.mDesc.GetLengths()[1],
|
||||
arg.input_.mDesc.GetLengths()[2],
|
||||
arg.input_.mDesc.GetLengths()[3],
|
||||
arg.input_.mDesc.GetLengths()[4])(
|
||||
arg.input_.GetLengths()[0],
|
||||
arg.input_.GetLengths()[1],
|
||||
arg.input_.GetLengths()[2],
|
||||
arg.input_.GetLengths()[3],
|
||||
arg.input_.GetLengths()[4],
|
||||
arg.input_.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -7,21 +7,25 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
|
||||
template <typename InDataType,
|
||||
// input descriptor in [G, N, C, Di, Hi, Wi] order
|
||||
// weight descriptor in [G, K, C, Z, Y, X] order
|
||||
// output descriptor in [G, N, K, Do, Ho, Wo] order
|
||||
// physical layout is irrelevant
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t NumDimSpatial = 2,
|
||||
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
auto f_kcx = [&](auto k, auto c, auto x) {
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto f_kcx = [&](auto g, auto k, auto c, auto x) {
|
||||
float v_acc = 0;
|
||||
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
|
||||
|
||||
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
|
||||
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
|
||||
{
|
||||
auto wi =
|
||||
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
|
||||
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
|
||||
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
|
||||
arg.out_element_op_(v_out,
|
||||
ck::type_convert<float>(arg.output_(n, k, wo)));
|
||||
arg.in_element_op_(v_in,
|
||||
ck::type_convert<float>(arg.input_(n, c, wi)));
|
||||
arg.out_element_op_(
|
||||
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
|
||||
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_wei;
|
||||
|
||||
arg.wei_element_op_(v_wei, v_acc);
|
||||
|
||||
arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
|
||||
arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_kcx,
|
||||
arg.weight_.mDesc.GetLengths()[0],
|
||||
arg.weight_.mDesc.GetLengths()[1],
|
||||
arg.weight_.mDesc.GetLengths()[2])(
|
||||
arg.weight_.GetLengths()[0],
|
||||
arg.weight_.GetLengths()[1],
|
||||
arg.weight_.GetLengths()[2],
|
||||
arg.weight_.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
|
||||
auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
|
||||
|
||||
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
|
||||
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
|
||||
{
|
||||
auto hi =
|
||||
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
|
||||
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
|
||||
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
|
||||
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
|
||||
{
|
||||
auto wi =
|
||||
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
|
||||
ck::type_convert<ck::long_index_t>(x *
|
||||
arg.conv_dilations_[I1]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
|
||||
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
|
||||
if(hi >= 0 &&
|
||||
ck::type_convert<std::size_t>(hi) <
|
||||
arg.input_.mDesc.GetLengths()[2] &&
|
||||
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
|
||||
wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) <
|
||||
arg.input_.mDesc.GetLengths()[3])
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
|
||||
arg.out_element_op_(
|
||||
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
|
||||
v_out,
|
||||
ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
|
||||
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_wei;
|
||||
|
||||
arg.wei_element_op_(v_wei, v_acc);
|
||||
|
||||
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
|
||||
arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_kcyx,
|
||||
arg.weight_.mDesc.GetLengths()[0],
|
||||
arg.weight_.mDesc.GetLengths()[1],
|
||||
arg.weight_.mDesc.GetLengths()[2],
|
||||
arg.weight_.mDesc.GetLengths()[3])(
|
||||
arg.weight_.GetLengths()[0],
|
||||
arg.weight_.GetLengths()[1],
|
||||
arg.weight_.GetLengths()[2],
|
||||
arg.weight_.GetLengths()[3],
|
||||
arg.weight_.GetLengths()[4])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
|
||||
auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
|
||||
|
||||
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
|
||||
for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
|
||||
{
|
||||
auto di =
|
||||
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
|
||||
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
|
||||
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
|
||||
auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
|
||||
{
|
||||
auto hi =
|
||||
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
|
||||
ck::type_convert<ck::long_index_t>(y *
|
||||
arg.conv_dilations_[I1]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
|
||||
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
|
||||
++wo)
|
||||
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
|
||||
{
|
||||
auto wi =
|
||||
ck::type_convert<ck::long_index_t>(wo *
|
||||
arg.conv_strides_[I2]) +
|
||||
ck::type_convert<ck::long_index_t>(
|
||||
x * arg.conv_dilations_[I2]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
|
||||
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
|
||||
|
||||
if(di >= 0 &&
|
||||
ck::type_convert<std::size_t>(di) <
|
||||
arg.input_.mDesc.GetLengths()[2] &&
|
||||
arg.input_.GetLengths()[3] &&
|
||||
hi >= 0 &&
|
||||
ck::type_convert<std::size_t>(hi) <
|
||||
arg.input_.mDesc.GetLengths()[3] &&
|
||||
arg.input_.GetLengths()[4] &&
|
||||
wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) <
|
||||
arg.input_.mDesc.GetLengths()[4])
|
||||
arg.input_.GetLengths()[5])
|
||||
{
|
||||
float v_out;
|
||||
float v_in;
|
||||
|
||||
arg.out_element_op_(v_out,
|
||||
ck::type_convert<float>(
|
||||
arg.output_(n, k, do_, ho, wo)));
|
||||
arg.in_element_op_(
|
||||
v_in,
|
||||
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
|
||||
arg.output_(g, n, k, do_, ho, wo)));
|
||||
|
||||
arg.in_element_op_(v_in,
|
||||
ck::type_convert<float>(
|
||||
arg.input_(g, n, c, di, hi, wi)));
|
||||
|
||||
v_acc += v_out * v_in;
|
||||
}
|
||||
@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_wei;
|
||||
|
||||
arg.wei_element_op_(v_wei, v_acc);
|
||||
|
||||
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
|
||||
arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_kczyx,
|
||||
arg.weight_.mDesc.GetLengths()[0],
|
||||
arg.weight_.mDesc.GetLengths()[1],
|
||||
arg.weight_.mDesc.GetLengths()[2],
|
||||
arg.weight_.mDesc.GetLengths()[3],
|
||||
arg.weight_.mDesc.GetLengths()[4])(
|
||||
arg.weight_.GetLengths()[0],
|
||||
arg.weight_.GetLengths()[1],
|
||||
arg.weight_.GetLengths()[2],
|
||||
arg.weight_.GetLengths()[3],
|
||||
arg.weight_.GetLengths()[4],
|
||||
arg.weight_.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
@@ -8,7 +8,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -17,9 +17,10 @@ namespace host {
|
||||
//
|
||||
// @brief Reference implementation for forward convolution.
|
||||
//
|
||||
// @paragraph Supports both NCHW as well as NHWC formats (and their respective
|
||||
// counterparts for weight and output) as long as tensor descriptor
|
||||
// lengths is in NCHW.
|
||||
// @paragraph
|
||||
// Tensor descriptor in GNCHW/GKCYX/GNKHW dimensional order
|
||||
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
|
||||
// as long as dimensions in tensor descriptor are in GNCHW order
|
||||
//
|
||||
// @tparam InDataType Input tensor data type.
|
||||
// @tparam WeiDataType Weights tensor data type.
|
||||
@@ -28,16 +29,20 @@ namespace host {
|
||||
// operation.
|
||||
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
|
||||
// operation.
|
||||
// @tparam NumDimSpatial Number of spatial dimensions.
|
||||
// @tparam NDimSpatial Number of spatial dimensions.
|
||||
//
|
||||
template <typename InDataType,
|
||||
// input descriptor in [G, N, C, Di, Hi, Wi] order
|
||||
// weight descriptor in [G, K, C, Z, Y, X] order
|
||||
// output descriptor in [G, N, K, Do, Ho, Wo] order
|
||||
// physical layout is irrelevant
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t NumDimSpatial = 2,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvFwd : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
|
||||
{
|
||||
auto f_ncw = [&](auto n, auto k, auto wo) {
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto func = [&](auto g, auto n, auto k, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
|
||||
{
|
||||
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
|
||||
for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
|
||||
{
|
||||
auto wi =
|
||||
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
|
||||
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
|
||||
{
|
||||
float v_in;
|
||||
float v_wei;
|
||||
|
||||
arg.in_element_op_(v_in,
|
||||
ck::type_convert<float>(arg.input_(n, c, wi)));
|
||||
arg.wei_element_op_(v_wei,
|
||||
ck::type_convert<float>(arg.weight_(k, c, x)));
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
|
||||
|
||||
arg.wei_element_op_(
|
||||
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
|
||||
|
||||
v_acc += v_in * v_wei;
|
||||
}
|
||||
@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
float v_out;
|
||||
|
||||
arg.out_element_op_(v_out, v_acc);
|
||||
arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
|
||||
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ncw,
|
||||
arg.output_.mDesc.GetLengths()[0],
|
||||
arg.output_.mDesc.GetLengths()[1],
|
||||
arg.output_.mDesc.GetLengths()[2])(
|
||||
make_ParallelTensorFunctor(func,
|
||||
arg.output_.GetLengths()[0],
|
||||
arg.output_.GetLengths()[1],
|
||||
arg.output_.GetLengths()[2],
|
||||
arg.output_.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
|
||||
{
|
||||
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
|
||||
for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
|
||||
{
|
||||
auto hi =
|
||||
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
|
||||
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
|
||||
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
|
||||
for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
|
||||
{
|
||||
auto wi =
|
||||
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
|
||||
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
|
||||
if(hi >= 0 &&
|
||||
ck::type_convert<std::size_t>(hi) <
|
||||
arg.input_.mDesc.GetLengths()[2] &&
|
||||
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
|
||||
wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) <
|
||||
arg.input_.mDesc.GetLengths()[3])
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
|
||||
{
|
||||
float v_in;
|
||||
float v_wei;
|
||||
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
|
||||
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
|
||||
|
||||
arg.wei_element_op_(
|
||||
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));
|
||||
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
|
||||
|
||||
v_acc += v_in * v_wei;
|
||||
}
|
||||
}
|
||||
@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
float v_out;
|
||||
|
||||
arg.out_element_op_(v_out, v_acc);
|
||||
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
|
||||
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.output_.mDesc.GetLengths()[0],
|
||||
arg.output_.mDesc.GetLengths()[1],
|
||||
arg.output_.mDesc.GetLengths()[2],
|
||||
arg.output_.mDesc.GetLengths()[3])(
|
||||
make_ParallelTensorFunctor(func,
|
||||
arg.output_.GetLengths()[0],
|
||||
arg.output_.GetLengths()[1],
|
||||
arg.output_.GetLengths()[2],
|
||||
arg.output_.GetLengths()[3],
|
||||
arg.output_.GetLengths()[4])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
|
||||
auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
|
||||
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
|
||||
{
|
||||
for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
|
||||
for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
|
||||
{
|
||||
auto di =
|
||||
ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
|
||||
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
|
||||
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
|
||||
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
|
||||
for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
|
||||
{
|
||||
auto hi =
|
||||
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
|
||||
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
|
||||
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
|
||||
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
|
||||
for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
|
||||
{
|
||||
auto wi =
|
||||
ck::type_convert<ck::long_index_t>(wo *
|
||||
arg.conv_strides_[2]) +
|
||||
ck::type_convert<ck::long_index_t>(x *
|
||||
arg.conv_dilations_[2]) -
|
||||
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
|
||||
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
|
||||
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
|
||||
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
|
||||
if(di >= 0 &&
|
||||
ck::type_convert<std::size_t>(di) <
|
||||
arg.input_.mDesc.GetLengths()[2] &&
|
||||
arg.input_.GetLengths()[3] &&
|
||||
hi >= 0 &&
|
||||
ck::type_convert<std::size_t>(hi) <
|
||||
arg.input_.mDesc.GetLengths()[3] &&
|
||||
arg.input_.GetLengths()[4] &&
|
||||
wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) <
|
||||
arg.input_.mDesc.GetLengths()[4])
|
||||
arg.input_.GetLengths()[5])
|
||||
{
|
||||
float v_in;
|
||||
float v_wei;
|
||||
|
||||
arg.in_element_op_(
|
||||
v_in,
|
||||
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
|
||||
arg.in_element_op_(v_in,
|
||||
ck::type_convert<float>(
|
||||
arg.input_(g, n, c, di, hi, wi)));
|
||||
|
||||
arg.wei_element_op_(
|
||||
v_wei,
|
||||
ck::type_convert<float>(arg.weight_(k, c, z, y, x)));
|
||||
ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
|
||||
|
||||
v_acc += v_in * v_wei;
|
||||
}
|
||||
}
|
||||
@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
float v_out;
|
||||
|
||||
arg.out_element_op_(v_out, v_acc);
|
||||
arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
|
||||
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.output_.mDesc.GetLengths()[0],
|
||||
arg.output_.mDesc.GetLengths()[1],
|
||||
arg.output_.mDesc.GetLengths()[2],
|
||||
arg.output_.mDesc.GetLengths()[3],
|
||||
arg.output_.mDesc.GetLengths()[4])(
|
||||
make_ParallelTensorFunctor(func,
|
||||
arg.output_.GetLengths()[0],
|
||||
arg.output_.GetLengths()[1],
|
||||
arg.output_.GetLengths()[2],
|
||||
arg.output_.GetLengths()[3],
|
||||
arg.output_.GetLengths()[4],
|
||||
arg.output_.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override
|
||||
{
|
||||
return NDimSpatial >= 1 && NDimSpatial <= 3;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -9,8 +9,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -9,8 +9,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -10,22 +10,67 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// aliasing, for commonly used type
|
||||
// aliasing, for commonly used data type
|
||||
using F64 = double;
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using EMPTY_TUPLE = ck::Tuple<>;
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
using F16_TUPLE = ck::Tuple<F16>;
|
||||
using F16_F16_TUPLE = ck::Tuple<F16, F16>;
|
||||
using F16_Tuple = ck::Tuple<F16>;
|
||||
using F16_F16_Tuple = ck::Tuple<F16, F16>;
|
||||
|
||||
using F32_TUPLE = ck::Tuple<F32>;
|
||||
using F32_Tuple = ck::Tuple<F32>;
|
||||
|
||||
// GEMM layout
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
using Row_Row_Tuple = ck::Tuple<Row, Row>;
|
||||
|
||||
// Conv layout
|
||||
//
|
||||
using NWC = ck::tensor_layout::convolution::NWC;
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
using NDHWC = ck::tensor_layout::convolution::NDHWC;
|
||||
|
||||
using KXC = ck::tensor_layout::convolution::KXC;
|
||||
using KYXC = ck::tensor_layout::convolution::KYXC;
|
||||
using KZYXC = ck::tensor_layout::convolution::KZYXC;
|
||||
|
||||
using NWK = ck::tensor_layout::convolution::NWK;
|
||||
using NHWK = ck::tensor_layout::convolution::NHWK;
|
||||
using NDHWK = ck::tensor_layout::convolution::NDHWK;
|
||||
|
||||
//
|
||||
using GNWC = ck::tensor_layout::convolution::GNWC;
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
|
||||
|
||||
using GKXC = ck::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using GNWK = ck::tensor_layout::convolution::GNWK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
//
|
||||
using NWGC = ck::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using KXGC = ck::tensor_layout::convolution::KXGC;
|
||||
using KYXGC = ck::tensor_layout::convolution::KYXGC;
|
||||
using KZYXGC = ck::tensor_layout::convolution::KZYXGC;
|
||||
|
||||
using NWGK = ck::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
// pointwise functor
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
@@ -25,7 +25,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
F32_TUPLE,
|
||||
F32_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -37,7 +37,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
F32_TUPLE,
|
||||
F32_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -49,7 +49,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
F32_TUPLE,
|
||||
F32_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -61,7 +61,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
F32_TUPLE,
|
||||
F32_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
|
||||
@@ -25,7 +25,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
EMPTY_TUPLE,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -37,7 +37,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
EMPTY_TUPLE,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -49,7 +49,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
EMPTY_TUPLE,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -61,7 +61,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
|
||||
2,
|
||||
F32,
|
||||
F32,
|
||||
EMPTY_TUPLE,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
|
||||
@@ -0,0 +1,270 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward data
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv2d backward data
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv3d backward data
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdData<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceConvBwdData<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
|
||||
is_same_v<OutLayout, NWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
|
||||
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
|
||||
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,230 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward weight
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<1,
|
||||
NWC,
|
||||
KXC,
|
||||
NWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv2d backward weight
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// conv3d backward weight
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdWeight<3,
|
||||
NDHWC,
|
||||
KZYXC,
|
||||
NDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
|
||||
is_same_v<OutLayout, NWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
|
||||
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
|
||||
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,128 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d forward
|
||||
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvFwd<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvFwd<2,
|
||||
NHWC,
|
||||
KYXC,
|
||||
NHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceConvFwd<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceConvFwd<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
|
||||
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -19,49 +19,53 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_TUPLE,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_TUPLE,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_TUPLE,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_TUPLE,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -70,7 +74,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
|
||||
// GEMM + Add + Add + FastGelu
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
@@ -79,7 +85,8 @@ template <typename ALayout,
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DELayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
@@ -90,7 +97,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
DELayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
@@ -108,27 +116,31 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,49 +19,53 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_TUPLE,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_TUPLE,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_TUPLE,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_TUPLE,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -70,7 +74,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
// GEMM + Bilinear
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DELayout,
|
||||
typename DLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
@@ -78,7 +83,8 @@ template <typename ALayout,
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DELayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
@@ -89,7 +95,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
DELayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
@@ -106,24 +113,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DELayout, Row>)
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// grouped conv1d forward, GNWC/GKXC/GNWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// grouped conv2d forward, NHWGC/KYXGC/NHWGK
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
KYXGC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
|
||||
NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
Empty_Tuple,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
Empty_Tuple,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
Empty_Tuple,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
Empty_Tuple,
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
|
||||
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, KYXGC> && is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
// no instance
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
// no instance
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
// no instance
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -16,15 +16,14 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DsType = Tuple<>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsType,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -33,10 +32,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsType,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -45,10 +45,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsType,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -57,10 +58,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsType,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -68,18 +70,18 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
Empty_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
Empty_Tuple,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -87,10 +89,11 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
Empty_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
Empty_Tuple,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -104,22 +107,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
@@ -194,10 +196,3 @@ check_err(const std::vector<T>& out,
|
||||
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
|
||||
// Writes every element of `v` to `os`, each one followed by a single space
// (a non-empty vector therefore ends with a trailing space), and returns `os`
// so that insertions can be chained.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    for(const T& elem : v)
    {
        os << elem << " ";
    }

    return os;
}
|
||||
|
||||
@@ -1,574 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/op_instance_engine.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
|
||||
element_wise::PassThrough,
|
||||
element_wise::PassThrough>;
|
||||
namespace instance {
|
||||
|
||||
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
|
||||
} // namespace instance
|
||||
namespace instance {
|
||||
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
|
||||
} // namespace instance
|
||||
namespace instance {
|
||||
|
||||
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
|
||||
} // namespace instance
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
namespace conv {
|
||||
|
||||
using DeviceConvFwdNoOpPtr =
|
||||
ck::tensor_operation::device::DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
/**
|
||||
* @brief Calculate number of FLOPs for Convolution
|
||||
*
|
||||
* @param[in] N Batch size.
|
||||
* @param[in] C Number of input channels.
|
||||
* @param[in] K Number of output channels.
|
||||
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
|
||||
* @param[in] output_spatial_lengths Convolution output spatial dimensions
|
||||
* lengths.
|
||||
*
|
||||
* @return The number of flops.
|
||||
*/
|
||||
std::size_t get_flops(ck::index_t N,
|
||||
ck::index_t C,
|
||||
ck::index_t K,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths);
|
||||
|
||||
/**
|
||||
* @brief Calculate number of bytes read/write by convolution algorithm.
|
||||
*
|
||||
* @param[in] N Batch size.
|
||||
* @param[in] C Number of input channels.
|
||||
* @param[in] K Number of output channels.
|
||||
* @param[in] input_spatial_lengths Input spatial dimensions lengths.
|
||||
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
|
||||
* @param[in] output_spatial_lengths Output spatial dimensions lengths
|
||||
*
|
||||
* @tparam InDataType Input tensor data type.
|
||||
* @tparam WeiDataType Weights tensor data type.
|
||||
* @tparam OutDataType Output tensor data type.
|
||||
*
|
||||
* @return The number of used bytes.
|
||||
*/
|
||||
template <typename InDataType = float,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType>
|
||||
std::size_t get_btype(ck::index_t N,
|
||||
ck::index_t C,
|
||||
ck::index_t K,
|
||||
const std::vector<ck::index_t>& input_spatial_lengths,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths)
|
||||
{
|
||||
// sizeof(InDataType) * (N * C * <input spatial lengths product>) +
|
||||
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
|
||||
// sizeof(OutDataType) * (N * K * <output spatial lengths product>);
|
||||
return sizeof(InDataType) * (N * C *
|
||||
std::accumulate(std::begin(input_spatial_lengths),
|
||||
std::end(input_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(WeiDataType) * (K * C *
|
||||
std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(OutDataType) * (N * K *
|
||||
std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
struct ConvParams
|
||||
{
|
||||
ConvParams();
|
||||
ConvParams(ck::index_t n_dim,
|
||||
ck::index_t n_batch,
|
||||
ck::index_t n_out_channels,
|
||||
ck::index_t n_in_channels,
|
||||
const std::vector<ck::index_t>& filters_len,
|
||||
const std::vector<ck::index_t>& input_len,
|
||||
const std::vector<ck::index_t>& strides,
|
||||
const std::vector<ck::index_t>& dilations,
|
||||
const std::vector<ck::index_t>& left_pads,
|
||||
const std::vector<ck::index_t>& right_pads);
|
||||
|
||||
ck::index_t num_dim_spatial_;
|
||||
ck::index_t N_;
|
||||
ck::index_t K_;
|
||||
ck::index_t C_;
|
||||
|
||||
std::vector<ck::index_t> filter_spatial_lengths_;
|
||||
std::vector<ck::index_t> input_spatial_lengths_;
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides_;
|
||||
std::vector<ck::index_t> conv_filter_dilations_;
|
||||
|
||||
std::vector<ck::index_t> input_left_pads_;
|
||||
std::vector<ck::index_t> input_right_pads_;
|
||||
|
||||
std::vector<ck::index_t> GetOutputSpatialLengths() const;
|
||||
};
|
||||
|
||||
ConvParams parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]);
|
||||
|
||||
/**
|
||||
* @brief Gets the host tensor descriptor.
|
||||
*
|
||||
* @param[in] dims The tensor dimensions lengths. Always in NCHW format.
|
||||
* @param[in] layout The tensor data layout.
|
||||
*
|
||||
* @tparam TensorLayout Layout type.
|
||||
*
|
||||
* @return The host tensor descriptor object.
|
||||
*/
|
||||
template <typename TensorLayout>
|
||||
HostTensorDescriptor get_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
const TensorLayout& layout)
|
||||
{
|
||||
std::size_t C = dims[1];
|
||||
// 1D
|
||||
if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCW>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCX>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKW>::value)
|
||||
{
|
||||
|
||||
return HostTensorDescriptor(dims, std::vector<std::size_t>{C * dims[2], dims[2], 1});
|
||||
}
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NWC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KXC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor(dims, std::vector<std::size_t>{C * dims[2], 1, C});
|
||||
}
|
||||
// 2D
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCHW>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCYX>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
|
||||
return HostTensorDescriptor(
|
||||
dims, std::vector<std::size_t>{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1});
|
||||
}
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KYXC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor(
|
||||
dims, std::vector<std::size_t>{C * dims[2] * dims[3], 1, dims[3] * C, C});
|
||||
}
|
||||
// 3D
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCZYX>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKDHW>::value)
|
||||
{
|
||||
|
||||
return HostTensorDescriptor(dims,
|
||||
std::vector<std::size_t>{C * dims[2] * dims[3] * dims[4],
|
||||
dims[2] * dims[3] * dims[4],
|
||||
dims[3] * dims[4],
|
||||
dims[4],
|
||||
1});
|
||||
}
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KZYXC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor(
|
||||
dims,
|
||||
std::vector<std::size_t>{
|
||||
C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C});
|
||||
}
|
||||
|
||||
std::stringstream err_msg;
|
||||
err_msg << "Unsupported data layout provided: " << layout << "!";
|
||||
throw std::runtime_error(err_msg.str());
|
||||
}
|
||||
|
||||
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2);
|
||||
|
||||
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2);
|
||||
|
||||
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2);
|
||||
|
||||
template <ck::index_t NDim,
|
||||
typename InDataType = float,
|
||||
typename WeiDataType = float,
|
||||
typename OutDataType = float>
|
||||
void run_reference_convolution_forward(const ConvParams& params,
|
||||
const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weights,
|
||||
Tensor<OutDataType>& output)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
NDim>();
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
output,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
struct ConvolutionFwdInstances;
|
||||
|
||||
template <>
|
||||
struct ConvolutionFwdInstances<float, float, float>
|
||||
{
|
||||
template <int NumDimSpatial,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get()
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvolutionFwdInstances<half_t, half_t, half_t>
|
||||
{
|
||||
template <int NumDimSpatial,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get()
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
|
||||
return conv_ptrs;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t>
|
||||
{
|
||||
template <int NumDimSpatial,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get()
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvolutionFwdInstances<int8_t, int8_t, int8_t>
|
||||
{
|
||||
template <int NumDimSpatial,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get()
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout = ck::tensor_layout::convolution::NHWC,
|
||||
typename WeiLayout = ck::tensor_layout::convolution::KYXC,
|
||||
typename OutLayout = ck::tensor_layout::convolution::NHWK,
|
||||
typename InElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename WeiElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename OutElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename InputInitFun = FillUniformDistribution<InDataType>,
|
||||
typename WeightsInitFun = FillUniformDistribution<WeiDataType>>
|
||||
class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>
|
||||
{
|
||||
using DeviceConvFwdOp = tensor_operation::device::
|
||||
DeviceConvFwd<InElementwiseOp, WeiElementwiseOp, OutElementwiseOp>;
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
using DeviceBuffers = std::vector<DeviceMemPtr>;
|
||||
using BaseType = ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>;
|
||||
template <typename T>
|
||||
using TensorPtr = std::unique_ptr<Tensor<T>>;
|
||||
using InTensorsTuple = std::tuple<TensorPtr<InDataType>, TensorPtr<WeiDataType>>;
|
||||
|
||||
public:
|
||||
ConvFwdOpInstance() = delete;
|
||||
ConvFwdOpInstance(const ConvFwdOpInstance&) = default;
|
||||
ConvFwdOpInstance& operator=(const ConvFwdOpInstance&) = default;
|
||||
|
||||
ConvFwdOpInstance(const ConvParams& params,
|
||||
bool do_init = true,
|
||||
const InputInitFun& input_init_f = InputInitFun(),
|
||||
const WeightsInitFun& weights_init_f = WeightsInitFun())
|
||||
: BaseType(),
|
||||
params_{params},
|
||||
output_spatial_lengths_{params.GetOutputSpatialLengths()},
|
||||
do_init_{do_init},
|
||||
input_init_f_{input_init_f},
|
||||
weights_init_f_{weights_init_f}
|
||||
{
|
||||
}
|
||||
|
||||
virtual ~ConvFwdOpInstance() override{};
|
||||
|
||||
virtual InTensorsTuple GetInputTensors() const override
|
||||
{
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params_.N_),
|
||||
static_cast<std::size_t>(params_.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params_.input_spatial_lengths_),
|
||||
std::end(params_.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params_.K_),
|
||||
static_cast<std::size_t>(params_.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params_.filter_spatial_lengths_),
|
||||
std::end(params_.filter_spatial_lengths_));
|
||||
|
||||
auto input = std::make_unique<Tensor<InDataType>>(
|
||||
get_host_tensor_descriptor(input_dims, InLayout{}));
|
||||
auto weights = std::make_unique<Tensor<WeiDataType>>(
|
||||
get_host_tensor_descriptor(filter_dims, WeiLayout{}));
|
||||
|
||||
if(do_init_)
|
||||
{
|
||||
input_init_f_(input->begin(), input->end());
|
||||
weights_init_f_(weights->begin(), weights->end());
|
||||
}
|
||||
|
||||
return std::make_tuple(std::move(input), std::move(weights));
|
||||
}
|
||||
|
||||
virtual TensorPtr<OutDataType> GetOutputTensor() const override
|
||||
{
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params_.N_),
|
||||
static_cast<std::size_t>(params_.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths_),
|
||||
std::end(output_spatial_lengths_));
|
||||
auto output = std::make_unique<Tensor<OutDataType>>(
|
||||
get_host_tensor_descriptor(output_dims, OutLayout{}));
|
||||
|
||||
if(do_init_)
|
||||
{
|
||||
std::fill(output->begin(), output->end(), OutDataType(0.f));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<tensor_operation::device::BaseInvoker>
|
||||
MakeInvokerPointer(tensor_operation::device::BaseOperator* op_ptr) const override
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<InElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
|
||||
static_assert(
|
||||
std::is_same_v<OutElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
|
||||
static_assert(
|
||||
std::is_same_v<WeiElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
|
||||
|
||||
auto conv_ptr = dynamic_cast<DeviceConvFwdOp*>(op_ptr);
|
||||
if(!conv_ptr)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!");
|
||||
}
|
||||
return conv_ptr->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<tensor_operation::device::BaseArgument>
|
||||
MakeArgumentPointer(tensor_operation::device::BaseOperator* op_ptr,
|
||||
const DeviceBuffers& in_device_buffers,
|
||||
const DeviceMemPtr& out_device_buffer) const override
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<InElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
|
||||
static_assert(
|
||||
std::is_same_v<OutElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
|
||||
static_assert(
|
||||
std::is_same_v<WeiElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
|
||||
|
||||
auto conv_ptr = dynamic_cast<DeviceConvFwdOp*>(op_ptr);
|
||||
if(!conv_ptr)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!");
|
||||
}
|
||||
|
||||
return conv_ptr->MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buffers[0]->GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(in_device_buffers[1]->GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buffer->GetDeviceBuffer()),
|
||||
params_.N_,
|
||||
params_.K_,
|
||||
params_.C_,
|
||||
params_.input_spatial_lengths_,
|
||||
params_.filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
params_.conv_filter_strides_,
|
||||
params_.conv_filter_dilations_,
|
||||
params_.input_left_pads_,
|
||||
params_.input_right_pads_,
|
||||
InElementwiseOp{},
|
||||
WeiElementwiseOp{},
|
||||
OutElementwiseOp{});
|
||||
}
|
||||
|
||||
virtual std::size_t GetFlops() const override
|
||||
{
|
||||
return get_flops(params_.N_,
|
||||
params_.C_,
|
||||
params_.K_,
|
||||
params_.filter_spatial_lengths_,
|
||||
output_spatial_lengths_);
|
||||
}
|
||||
|
||||
virtual std::size_t GetBtype() const override
|
||||
{
|
||||
return get_btype<InDataType, WeiDataType, OutDataType>(params_.N_,
|
||||
params_.C_,
|
||||
params_.K_,
|
||||
params_.input_spatial_lengths_,
|
||||
params_.filter_spatial_lengths_,
|
||||
output_spatial_lengths_);
|
||||
}
|
||||
|
||||
private:
|
||||
const ConvParams& params_;
|
||||
const std::vector<ck::index_t> output_spatial_lengths_;
|
||||
const bool do_init_;
|
||||
InputInitFun input_init_f_;
|
||||
WeightsInitFun weights_init_f_;
|
||||
};
|
||||
|
||||
} // namespace conv
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParams& p);
|
||||
@@ -0,0 +1,354 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
namespace conv {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename OldLayout>
|
||||
std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
|
||||
{
|
||||
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
|
||||
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
|
||||
// TODO: remove this branch after removing legacy kernel
|
||||
if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KXC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWK>)
|
||||
{
|
||||
return {0, 1, 3, 2};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KYXC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWK>)
|
||||
{
|
||||
return {0, 1, 4, 2, 3};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KZYXC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWK>)
|
||||
{
|
||||
return {0, 1, 5, 2, 3, 4};
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCW> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCX> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKW>)
|
||||
{
|
||||
return {0, 1, 2, 3};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
|
||||
{
|
||||
return {0, 1, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCDHW> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCZYX> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKDHW>)
|
||||
{
|
||||
return {0, 1, 2, 3, 4, 5};
|
||||
}
|
||||
if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNWC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKXC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNWK>)
|
||||
{
|
||||
return {0, 1, 3, 2};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNHWC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKYXC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNHWK>)
|
||||
{
|
||||
return {0, 1, 4, 2, 3};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNDHWC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKZYXC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNDHWK>)
|
||||
{
|
||||
return {0, 1, 5, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWGC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KXGC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWGK>)
|
||||
{
|
||||
return {2, 0, 3, 1};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWGC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KYXGC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWGK>)
|
||||
{
|
||||
return {3, 0, 4, 1, 2};
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWGC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KZYXGC> ||
|
||||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
return {4, 0, 5, 1, 2, 3};
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
|
||||
// regardless of physical layout
|
||||
template <typename InLayout>
|
||||
HostTensorDescriptor
|
||||
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
|
||||
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
|
||||
// TODO: remove this branch after removing legacy kernel
|
||||
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
|
||||
{
|
||||
if(param.G_ != 1)
|
||||
{
|
||||
throw std::runtime_error("wrong! G != 1");
|
||||
}
|
||||
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNWC> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNHWC> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNDHWC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWGC> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWGC> ||
|
||||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWGC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", InLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<InLayout>());
|
||||
}
|
||||
|
||||
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
|
||||
// regardless of physical layout
|
||||
template <typename WeiLayout>
|
||||
HostTensorDescriptor
|
||||
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck::utils::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
|
||||
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
|
||||
// TODO: remove this branch after removing legacy kernel
|
||||
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
|
||||
{
|
||||
if(param.G_ != 1)
|
||||
{
|
||||
throw std::runtime_error("wrong! G != 1");
|
||||
}
|
||||
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
|
||||
{
|
||||
if(param.G_ != 1)
|
||||
{
|
||||
throw std::runtime_error("wrong! G != 1");
|
||||
}
|
||||
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCX> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCYX> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCZYX>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKXC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKYXC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKZYXC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXGC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXGC> ||
|
||||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXGC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", WeiLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
|
||||
}
|
||||
|
||||
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
|
||||
// regardless of physical layout
|
||||
template <typename OutLayout>
|
||||
HostTensorDescriptor
|
||||
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
|
||||
// is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK
|
||||
// TODO: remove this branch after removing legacy kernel
|
||||
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
|
||||
{
|
||||
if(param.G_ != 1)
|
||||
{
|
||||
throw std::runtime_error("wrong! G != 1");
|
||||
}
|
||||
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKW> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKHW> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWGK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWGK> ||
|
||||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", OutLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
86
library/include/ck/library/utility/convolution_parameter.hpp
Normal file
86
library/include/ck/library/utility/convolution_parameter.hpp
Normal file
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <numeric>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
namespace conv {
|
||||
|
||||
struct ConvParam
|
||||
{
|
||||
ConvParam();
|
||||
ConvParam(ck::index_t n_dim,
|
||||
ck::index_t group_count,
|
||||
ck::index_t n_batch,
|
||||
ck::index_t n_out_channels,
|
||||
ck::index_t n_in_channels,
|
||||
const std::vector<ck::index_t>& filters_len,
|
||||
const std::vector<ck::index_t>& input_len,
|
||||
const std::vector<ck::index_t>& strides,
|
||||
const std::vector<ck::index_t>& dilations,
|
||||
const std::vector<ck::index_t>& left_pads,
|
||||
const std::vector<ck::index_t>& right_pads);
|
||||
|
||||
ck::index_t num_dim_spatial_;
|
||||
ck::index_t G_;
|
||||
ck::index_t N_;
|
||||
ck::index_t K_;
|
||||
ck::index_t C_;
|
||||
|
||||
std::vector<ck::index_t> filter_spatial_lengths_;
|
||||
std::vector<ck::index_t> input_spatial_lengths_;
|
||||
std::vector<ck::index_t> output_spatial_lengths_;
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides_;
|
||||
std::vector<ck::index_t> conv_filter_dilations_;
|
||||
|
||||
std::vector<ck::index_t> input_left_pads_;
|
||||
std::vector<ck::index_t> input_right_pads_;
|
||||
|
||||
std::vector<ck::index_t> GetOutputSpatialLengths() const;
|
||||
|
||||
std::size_t GetFlops() const;
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
std::size_t GetByte() const
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(InDataType) *
|
||||
(G_ * N_ * C_ *
|
||||
std::accumulate(std::begin(input_spatial_lengths_),
|
||||
std::begin(input_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(WeiDataType) *
|
||||
(G_ * K_ * C_ *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(OutDataType) * (G_ * N_ * K_ *
|
||||
std::accumulate(std::begin(output_spatial_lengths_),
|
||||
std::end(output_spatial_lengths_),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
};
|
||||
|
||||
std::string get_conv_param_parser_helper_msg();
|
||||
|
||||
ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]);
|
||||
|
||||
} // namespace conv
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);
|
||||
@@ -11,8 +11,8 @@
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/utility/reduction_common.hpp"
|
||||
#include "ck/utility/reduction_functions_accumulate.hpp"
|
||||
#include "ck/library/host_tensor/host_common_util.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
template <int NDim>
|
||||
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
|
||||
@@ -73,22 +73,41 @@ auto construct_f_unpack_args(F, T args)
|
||||
|
||||
struct HostTensorDescriptor
|
||||
{
|
||||
HostTensorDescriptor() = delete;
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor(const std::vector<X>& lens);
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides);
|
||||
HostTensorDescriptor() = default;
|
||||
|
||||
void CalculateStrides();
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens.begin(), lens.end())
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename Range>
|
||||
HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor(const std::initializer_list<X>& lens,
|
||||
const std::initializer_list<Y>& strides)
|
||||
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides)
|
||||
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Range1, typename Range2>
|
||||
HostTensorDescriptor(const Range1& lens, const Range2& strides)
|
||||
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
|
||||
@@ -97,7 +116,7 @@ struct HostTensorDescriptor
|
||||
|
||||
std::size_t GetNumOfDimension() const;
|
||||
std::size_t GetElementSize() const;
|
||||
std::size_t GetElementSpace() const;
|
||||
std::size_t GetElementSpaceSize() const;
|
||||
|
||||
const std::vector<std::size_t>& GetLengths() const;
|
||||
const std::vector<std::size_t>& GetStrides() const;
|
||||
@@ -122,6 +141,22 @@ struct HostTensorDescriptor
|
||||
std::vector<std::size_t> mStrides;
|
||||
};
|
||||
|
||||
template <typename New2Old>
|
||||
HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
|
||||
const New2Old& new2old)
|
||||
{
|
||||
std::vector<std::size_t> new_lengths(a.GetNumOfDimension());
|
||||
std::vector<std::size_t> new_strides(a.GetNumOfDimension());
|
||||
|
||||
for(std::size_t i = 0; i < a.GetNumOfDimension(); i++)
|
||||
{
|
||||
new_lengths[i] = a.GetLengths()[new2old[i]];
|
||||
new_strides[i] = a.GetStrides()[new2old[i]];
|
||||
}
|
||||
|
||||
return HostTensorDescriptor(new_lengths, new_strides);
|
||||
}
|
||||
|
||||
struct joinable_thread : std::thread
|
||||
{
|
||||
template <typename... Xs>
|
||||
@@ -203,22 +238,22 @@ template <typename T>
|
||||
struct Tensor
|
||||
{
|
||||
template <typename X>
|
||||
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
|
||||
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
|
||||
Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
Tensor(std::vector<X> lens, std::vector<Y> strides)
|
||||
: mDesc(lens, strides), mData(mDesc.GetElementSpace())
|
||||
: mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
|
||||
{
|
||||
}
|
||||
|
||||
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
|
||||
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
|
||||
|
||||
template <typename OutT>
|
||||
Tensor<OutT> CopyAsType()
|
||||
@@ -240,6 +275,24 @@ struct Tensor
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Forwards to mDesc.GetLengths(): the per-dimension lengths held by this tensor's descriptor.
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
|
||||
|
||||
// Forwards to mDesc.GetStrides(): the per-dimension strides held by this tensor's descriptor.
const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
|
||||
|
||||
// Forwards to mDesc.GetNumOfDimension().
std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
|
||||
|
||||
// Forwards to mDesc.GetElementSize().
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
|
||||
|
||||
// Forwards to mDesc.GetElementSpaceSize(); the constructors size mData with this same value.
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
|
||||
|
||||
void SetZero()
|
||||
{
|
||||
for(auto& v : mData)
|
||||
{
|
||||
v = T{0};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
|
||||
{
|
||||
@@ -330,6 +383,19 @@ struct Tensor
|
||||
mDesc.GetLengths()[4])(num_thread);
|
||||
break;
|
||||
}
|
||||
case 6: {
    // Rank-6 tensors: fill every element from g. The destination must be addressed with
    // all six indices; dropping i5 would address the tensor with one index too few.
    auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
        (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
    };
    make_ParallelTensorFunctor(f,
                               mDesc.GetLengths()[0],
                               mDesc.GetLengths()[1],
                               mDesc.GetLengths()[2],
                               mDesc.GetLengths()[3],
                               mDesc.GetLengths()[4],
                               mDesc.GetLengths()[5])(num_thread);
    break;
}
|
||||
default: throw std::runtime_error("unspported dimension");
|
||||
}
|
||||
}
|
||||
@@ -367,17 +433,3 @@ struct Tensor
|
||||
HostTensorDescriptor mDesc;
|
||||
std::vector<T> mData;
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
|
||||
: mLens(lens.begin(), lens.end())
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
|
||||
const std::vector<Y>& strides)
|
||||
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
|
||||
{
|
||||
}
|
||||
@@ -16,8 +16,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/host_tensor/device_memory.hpp"
|
||||
#include "ck/library/host_tensor/host_tensor.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
@@ -103,8 +103,8 @@ class OpInstanceRunEngine
|
||||
}
|
||||
}
|
||||
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
|
||||
out_device_buffer_ =
|
||||
std::make_unique<DeviceMem>(sizeof(OutDataType) * out_tensor_->mDesc.GetElementSpace());
|
||||
out_device_buffer_ = std::make_unique<DeviceMem>(sizeof(OutDataType) *
|
||||
out_tensor_->mDesc.GetElementSpaceSize());
|
||||
out_device_buffer_->SetZero();
|
||||
}
|
||||
|
||||
@@ -222,7 +222,7 @@ class OpInstanceRunEngine
|
||||
in_device_buffers_
|
||||
.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) *
|
||||
ts->mDesc.GetElementSpace()))
|
||||
ts->mDesc.GetElementSpaceSize()))
|
||||
->ToDevice(ts->mData.data());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user