Add support for groups in Img2Col/Col2Img (#1007)

* Add support for groups in Img2Col/Col2Img

* Fix interface test

* Fix interface test G to N

* Improve performance

* Change gemm layout to 3d

* Fixes

[ROCm/composable_kernel commit: 2e824c6d46]
This commit is contained in:
Bartłomiej Kocot
2023-10-31 10:46:32 +01:00
committed by GitHub
parent 1a389f2d2e
commit 60a0f176d3
30 changed files with 1114 additions and 281 deletions

View File

@@ -19,9 +19,7 @@ namespace host {
* \brief Reference implementation for column to image.
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
float Run(const Argument& arg)
{
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 2))
arg.input_.GetNumOfDimension() == 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{
float v_in = ck::type_convert<float>(arg.input_(row, column));
float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi));
arg.output_(0, n, c, wi) =
float v_in =
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
arg.output_(g, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[4])
{
float v_in =
ck::type_convert<float>(arg.input_(row, column));
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, hi, wi));
arg.output_(0, n, c, hi, wi) =
arg.output_(g, n, c, hi, wi));
arg.output_(g, n, c, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n) {
auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
arg.output_.GetLengths()[5])
{
float v_in = ck::type_convert<float>(
arg.input_(row, column));
arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, di, hi, wi));
arg.output_(0, n, c, di, hi, wi) =
arg.output_(g, n, c, di, hi, wi));
arg.output_(g, n, c, di, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{
return false;
}

View File

@@ -19,9 +19,7 @@ namespace host {
* \brief Reference implementation for image to column.
*
* Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
* Output tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G * N * Do * Ho * Wo, Z * Y * X * C] data layout.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
@@ -95,18 +93,19 @@ struct ReferenceImageToColumn : public device::BaseOperator
float Run(const Argument& arg)
{
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == 2))
arg.output_.GetNumOfDimension() == 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.input_.GetLengths()[0];
const index_t N = arg.input_.GetLengths()[1];
const index_t C = arg.input_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n, auto wo) {
auto func = [&](auto g, auto n, auto wo) {
index_t row = n * Wo + wo;
index_t column = 0;
@@ -121,15 +120,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
InDataType v_in = arg.input_(0, n, c, wi);
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in);
InDataType v_in = arg.input_(g, n, c, wi);
arg.output_(g, row, column) = ck::type_convert<OutDataType>(v_in);
}
column++;
}
}
};
make_ParallelTensorFunctor(func, N, Wo)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
return 0;
}
@@ -138,7 +137,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto ho, auto wo) {
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
@@ -162,8 +161,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
InDataType v_in = arg.input_(0, n, c, hi, wi);
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in);
InDataType v_in = arg.input_(g, n, c, hi, wi);
arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in);
}
column++;
}
@@ -171,7 +171,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N, Ho, Wo)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
return 0;
}
@@ -181,7 +181,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n, auto d_o, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
@@ -213,8 +213,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5])
{
InDataType v_in = arg.input_(0, n, c, di, hi, wi);
arg.output_(row, column) =
InDataType v_in = arg.input_(g, n, c, di, hi, wi);
arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in);
}
column++;
@@ -224,7 +224,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
}
};
make_ParallelTensorFunctor(func, N, Do, Ho, Wo)(
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(
std::thread::hardware_concurrency());
return 0;
@@ -267,8 +267,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
arg.output_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.output_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.output_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{
return false;
}

View File

@@ -19,109 +19,214 @@ namespace instance {
using namespace ck::conv_tensor_rearrange_op;
// GNWC/GNHWC/GNDHWC
// Image to Column
// nhwc, 1d
void add_device_image_to_column_nwc_1d_bf16_instances(
// GNWC, 1d
void add_device_image_to_column_gnwc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_f16_instances(
void add_device_image_to_column_gnwc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_f32_instances(
void add_device_image_to_column_gnwc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_i8_instances(
void add_device_image_to_column_gnwc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// nhwc, 2d
void add_device_image_to_column_nhwc_2d_bf16_instances(
// GNHWC, 2d
void add_device_image_to_column_gnhwc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_f16_instances(
void add_device_image_to_column_gnhwc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_f32_instances(
void add_device_image_to_column_gnhwc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_i8_instances(
void add_device_image_to_column_gnhwc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// nhwc, 3d
void add_device_image_to_column_ndhwc_3d_bf16_instances(
// GNDHWC, 3d
void add_device_image_to_column_gndhwc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_f16_instances(
void add_device_image_to_column_gndhwc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_f32_instances(
void add_device_image_to_column_gndhwc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_i8_instances(
void add_device_image_to_column_gndhwc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// Column to Image
// nhwc, 1d
void add_device_column_to_image_nwc_1d_bf16_instances(
// GNWC, 1d
void add_device_column_to_image_gnwc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_f16_instances(
void add_device_column_to_image_gnwc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_f32_instances(
void add_device_column_to_image_gnwc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_i8_instances(
void add_device_column_to_image_gnwc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ColumnToImage>>>&
instances);
// nhwc, 2d
void add_device_column_to_image_nhwc_2d_bf16_instances(
// GNHWC, 2d
void add_device_column_to_image_gnhwc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_f16_instances(
void add_device_column_to_image_gnhwc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_f32_instances(
void add_device_column_to_image_gnhwc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_i8_instances(
void add_device_column_to_image_gnhwc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ColumnToImage>>>&
instances);
// nhwc, 3d
void add_device_column_to_image_ndhwc_3d_bf16_instances(
// GNDHWC, 3d
void add_device_column_to_image_gndhwc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_f16_instances(
void add_device_column_to_image_gndhwc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_f32_instances(
void add_device_column_to_image_gndhwc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_i8_instances(
void add_device_column_to_image_gndhwc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ColumnToImage>>>&
instances);
// NWGC/NHWGC/NDHWGC
// Image to Column
// NWGC, 1d
void add_device_image_to_column_nwgc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwgc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwgc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwgc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, int8_t, int8_t, ImageToColumn>>>&
instances);
// NHWGC, 2d
void add_device_image_to_column_nhwgc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwgc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwgc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwgc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, int8_t, int8_t, ImageToColumn>>>&
instances);
// NDHWGC, 3d
void add_device_image_to_column_ndhwgc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwgc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwgc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwgc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, int8_t, int8_t, ImageToColumn>>>&
instances);
// Column to Image
// NWGC, 1d
void add_device_column_to_image_nwgc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwgc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwgc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwgc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, int8_t, int8_t, ColumnToImage>>>&
instances);
// NHWGC, 2d
void add_device_column_to_image_nhwgc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwgc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwgc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwgc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, int8_t, int8_t, ColumnToImage>>>&
instances);
// NDHWGC, 3d
void add_device_column_to_image_ndhwgc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwgc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwgc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwgc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, int8_t, int8_t, ColumnToImage>>>&
instances);
template <ck::index_t NumDimSpatial,
typename ImageLayout,
@@ -151,60 +256,120 @@ struct DeviceOperationInstanceFactory<
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nwc_1d_f32_instances(op_ptrs);
add_device_image_to_column_gnwc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nwc_1d_f16_instances(op_ptrs);
add_device_image_to_column_gnwc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nwc_1d_bf16_instances(op_ptrs);
add_device_image_to_column_gnwc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nwc_1d_i8_instances(op_ptrs);
add_device_image_to_column_gnwc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nhwc_2d_f32_instances(op_ptrs);
add_device_image_to_column_gnhwc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nhwc_2d_f16_instances(op_ptrs);
add_device_image_to_column_gnhwc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nhwc_2d_bf16_instances(op_ptrs);
add_device_image_to_column_gnhwc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nhwc_2d_i8_instances(op_ptrs);
add_device_image_to_column_gnhwc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_ndhwc_3d_f32_instances(op_ptrs);
add_device_image_to_column_gndhwc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_ndhwc_3d_f16_instances(op_ptrs);
add_device_image_to_column_gndhwc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_ndhwc_3d_bf16_instances(op_ptrs);
add_device_image_to_column_gndhwc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_ndhwc_3d_i8_instances(op_ptrs);
add_device_image_to_column_gndhwc_3d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, NWGC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nwgc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nwgc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nwgc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nwgc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, NHWGC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nhwgc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nhwgc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nhwgc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nhwgc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, NDHWGC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_ndhwgc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_ndhwgc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_ndhwgc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_ndhwgc_3d_i8_instances(op_ptrs);
}
}
}
@@ -214,60 +379,120 @@ struct DeviceOperationInstanceFactory<
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nwc_1d_f32_instances(op_ptrs);
add_device_column_to_image_gnwc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nwc_1d_f16_instances(op_ptrs);
add_device_column_to_image_gnwc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nwc_1d_bf16_instances(op_ptrs);
add_device_column_to_image_gnwc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nwc_1d_i8_instances(op_ptrs);
add_device_column_to_image_gnwc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nhwc_2d_f32_instances(op_ptrs);
add_device_column_to_image_gnhwc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nhwc_2d_f16_instances(op_ptrs);
add_device_column_to_image_gnhwc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nhwc_2d_bf16_instances(op_ptrs);
add_device_column_to_image_gnhwc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nhwc_2d_i8_instances(op_ptrs);
add_device_column_to_image_gnhwc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_ndhwc_3d_f32_instances(op_ptrs);
add_device_column_to_image_gndhwc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_ndhwc_3d_f16_instances(op_ptrs);
add_device_column_to_image_gndhwc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_ndhwc_3d_bf16_instances(op_ptrs);
add_device_column_to_image_gndhwc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_ndhwc_3d_i8_instances(op_ptrs);
add_device_column_to_image_gndhwc_3d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, NWGC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nwgc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nwgc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nwgc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nwgc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, NHWGC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nhwgc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nhwgc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nhwgc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nhwgc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, NDHWGC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_ndhwgc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_ndhwgc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_ndhwgc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_ndhwgc_3d_i8_instances(op_ptrs);
}
}
}