mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
Add support for groups in Img2Col/Col2Img (#1007)
* Add support for groups in Img2Col/Col2Img
* Fix interface test
* Fix interface test G to N
* Improve performance
* Change gemm layout to 3d
* Fixes
[ROCm/composable_kernel commit: 2e824c6d46]
This commit is contained in:
@@ -19,9 +19,7 @@ namespace host {
|
||||
* \brief Reference implementation for column to image.
|
||||
*
|
||||
* Input tensor descriptor has [G, N * Do * Ho * Wo, Z * Y * X * C] data layout.
|
||||
* Memory layout is the same.
|
||||
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
|
||||
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam ImageLayout Image Layout.
|
||||
@@ -95,18 +93,19 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.input_.GetNumOfDimension() == 2))
|
||||
arg.input_.GetNumOfDimension() == 3))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
const index_t G = arg.output_.GetLengths()[0];
|
||||
const index_t N = arg.output_.GetLengths()[1];
|
||||
const index_t C = arg.output_.GetLengths()[2];
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const index_t Wo = arg.output_spatial_lengths_[0];
|
||||
auto func = [&](auto n) {
|
||||
auto func = [&](auto g, auto n) {
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
index_t row = n * Wo + wo;
|
||||
@@ -123,9 +122,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
|
||||
{
|
||||
float v_in = ck::type_convert<float>(arg.input_(row, column));
|
||||
float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi));
|
||||
arg.output_(0, n, c, wi) =
|
||||
float v_in =
|
||||
ck::type_convert<float>(arg.input_(g, row, column));
|
||||
float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
|
||||
arg.output_(g, n, c, wi) =
|
||||
ck::type_convert<OutDataType>(v_in + v_out);
|
||||
}
|
||||
column++;
|
||||
@@ -134,7 +134,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -143,7 +143,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
const index_t Ho = arg.output_spatial_lengths_[0];
|
||||
const index_t Wo = arg.output_spatial_lengths_[1];
|
||||
|
||||
auto func = [&](auto n) {
|
||||
auto func = [&](auto g, auto n) {
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
@@ -176,10 +176,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
arg.output_.GetLengths()[4])
|
||||
{
|
||||
float v_in =
|
||||
ck::type_convert<float>(arg.input_(row, column));
|
||||
ck::type_convert<float>(arg.input_(g, row, column));
|
||||
float v_out = ck::type_convert<float>(
|
||||
arg.output_(0, n, c, hi, wi));
|
||||
arg.output_(0, n, c, hi, wi) =
|
||||
arg.output_(g, n, c, hi, wi));
|
||||
arg.output_(g, n, c, hi, wi) =
|
||||
ck::type_convert<OutDataType>(v_in + v_out);
|
||||
}
|
||||
column++;
|
||||
@@ -190,7 +190,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -200,7 +200,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
const index_t Ho = arg.output_spatial_lengths_[1];
|
||||
const index_t Wo = arg.output_spatial_lengths_[2];
|
||||
|
||||
auto func = [&](auto n) {
|
||||
auto func = [&](auto g, auto n) {
|
||||
for(index_t d_o = 0; d_o < Do; ++d_o)
|
||||
{
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
@@ -245,10 +245,10 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
arg.output_.GetLengths()[5])
|
||||
{
|
||||
float v_in = ck::type_convert<float>(
|
||||
arg.input_(row, column));
|
||||
arg.input_(g, row, column));
|
||||
float v_out = ck::type_convert<float>(
|
||||
arg.output_(0, n, c, di, hi, wi));
|
||||
arg.output_(0, n, c, di, hi, wi) =
|
||||
arg.output_(g, n, c, di, hi, wi));
|
||||
arg.output_(g, n, c, di, hi, wi) =
|
||||
ck::type_convert<OutDataType>(v_in + v_out);
|
||||
}
|
||||
column++;
|
||||
@@ -261,7 +261,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -303,8 +303,9 @@ struct ReferenceColumnToImage : public device::BaseOperator
|
||||
C * ck::accumulate_n<index_t>(
|
||||
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
|
||||
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
|
||||
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
|
||||
arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
|
||||
arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -19,9 +19,7 @@ namespace host {
|
||||
* \brief Reference implementation for image to column.
|
||||
*
|
||||
* Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
|
||||
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
|
||||
* Output tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
|
||||
* Memory layout is the same.
|
||||
* Output tensor descriptor has [G, N * Do * Ho * Wo, Z * Y * X * C] data layout.
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam ImageLayout Image Layout.
|
||||
@@ -95,18 +93,19 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
|
||||
arg.output_.GetNumOfDimension() == 2))
|
||||
arg.output_.GetNumOfDimension() == 3))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
const index_t G = arg.input_.GetLengths()[0];
|
||||
const index_t N = arg.input_.GetLengths()[1];
|
||||
const index_t C = arg.input_.GetLengths()[2];
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const index_t Wo = arg.output_spatial_lengths_[0];
|
||||
auto func = [&](auto n, auto wo) {
|
||||
auto func = [&](auto g, auto n, auto wo) {
|
||||
index_t row = n * Wo + wo;
|
||||
index_t column = 0;
|
||||
|
||||
@@ -121,15 +120,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
if(wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
|
||||
{
|
||||
InDataType v_in = arg.input_(0, n, c, wi);
|
||||
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in);
|
||||
InDataType v_in = arg.input_(g, n, c, wi);
|
||||
arg.output_(g, row, column) = ck::type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, N, Wo)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -138,7 +137,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
const index_t Ho = arg.output_spatial_lengths_[0];
|
||||
const index_t Wo = arg.output_spatial_lengths_[1];
|
||||
|
||||
auto func = [&](auto n, auto ho, auto wo) {
|
||||
auto func = [&](auto g, auto n, auto ho, auto wo) {
|
||||
index_t row = n * Ho * Wo + ho * Wo + wo;
|
||||
index_t column = 0;
|
||||
|
||||
@@ -162,8 +161,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
wi >= 0 &&
|
||||
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
|
||||
{
|
||||
InDataType v_in = arg.input_(0, n, c, hi, wi);
|
||||
arg.output_(row, column) = ck::type_convert<OutDataType>(v_in);
|
||||
InDataType v_in = arg.input_(g, n, c, hi, wi);
|
||||
arg.output_(g, row, column) =
|
||||
ck::type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
@@ -171,7 +171,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, N, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -181,7 +181,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
const index_t Ho = arg.output_spatial_lengths_[1];
|
||||
const index_t Wo = arg.output_spatial_lengths_[2];
|
||||
|
||||
auto func = [&](auto n, auto d_o, auto ho, auto wo) {
|
||||
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
|
||||
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
|
||||
index_t column = 0;
|
||||
|
||||
@@ -213,8 +213,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
ck::type_convert<std::size_t>(wi) <
|
||||
arg.input_.GetLengths()[5])
|
||||
{
|
||||
InDataType v_in = arg.input_(0, n, c, di, hi, wi);
|
||||
arg.output_(row, column) =
|
||||
InDataType v_in = arg.input_(g, n, c, di, hi, wi);
|
||||
arg.output_(g, row, column) =
|
||||
ck::type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
@@ -224,7 +224,7 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, N, Do, Ho, Wo)(
|
||||
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
@@ -267,8 +267,9 @@ struct ReferenceImageToColumn : public device::BaseOperator
|
||||
C * ck::accumulate_n<index_t>(
|
||||
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
|
||||
arg.output_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
|
||||
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
|
||||
arg.output_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
|
||||
arg.output_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -19,109 +19,214 @@ namespace instance {
|
||||
|
||||
using namespace ck::conv_tensor_rearrange_op;
|
||||
|
||||
// GNWC/GNHWC/GNDHWC
|
||||
// Image to Column
|
||||
// nhwc, 1d
|
||||
void add_device_image_to_column_nwc_1d_bf16_instances(
|
||||
// GNWC, 1d
|
||||
void add_device_image_to_column_gnwc_1d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nwc_1d_f16_instances(
|
||||
void add_device_image_to_column_gnwc_1d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nwc_1d_f32_instances(
|
||||
void add_device_image_to_column_gnwc_1d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nwc_1d_i8_instances(
|
||||
void add_device_image_to_column_gnwc_1d_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ImageToColumn>>>&
|
||||
instances);
|
||||
// nhwc, 2d
|
||||
void add_device_image_to_column_nhwc_2d_bf16_instances(
|
||||
// GNHWC, 2d
|
||||
void add_device_image_to_column_gnhwc_2d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nhwc_2d_f16_instances(
|
||||
void add_device_image_to_column_gnhwc_2d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nhwc_2d_f32_instances(
|
||||
void add_device_image_to_column_gnhwc_2d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nhwc_2d_i8_instances(
|
||||
void add_device_image_to_column_gnhwc_2d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ImageToColumn>>>&
|
||||
instances);
|
||||
// nhwc, 3d
|
||||
void add_device_image_to_column_ndhwc_3d_bf16_instances(
|
||||
// GNDHWC, 3d
|
||||
void add_device_image_to_column_gndhwc_3d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_ndhwc_3d_f16_instances(
|
||||
void add_device_image_to_column_gndhwc_3d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_ndhwc_3d_f32_instances(
|
||||
void add_device_image_to_column_gndhwc_3d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_ndhwc_3d_i8_instances(
|
||||
void add_device_image_to_column_gndhwc_3d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
// Column to Image
|
||||
// nhwc, 1d
|
||||
void add_device_column_to_image_nwc_1d_bf16_instances(
|
||||
// GNWC, 1d
|
||||
void add_device_column_to_image_gnwc_1d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nwc_1d_f16_instances(
|
||||
void add_device_column_to_image_gnwc_1d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nwc_1d_f32_instances(
|
||||
void add_device_column_to_image_gnwc_1d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nwc_1d_i8_instances(
|
||||
void add_device_column_to_image_gnwc_1d_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ColumnToImage>>>&
|
||||
instances);
|
||||
// nhwc, 2d
|
||||
void add_device_column_to_image_nhwc_2d_bf16_instances(
|
||||
// GNHWC, 2d
|
||||
void add_device_column_to_image_gnhwc_2d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nhwc_2d_f16_instances(
|
||||
void add_device_column_to_image_gnhwc_2d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nhwc_2d_f32_instances(
|
||||
void add_device_column_to_image_gnhwc_2d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nhwc_2d_i8_instances(
|
||||
void add_device_column_to_image_gnhwc_2d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ColumnToImage>>>&
|
||||
instances);
|
||||
// nhwc, 3d
|
||||
void add_device_column_to_image_ndhwc_3d_bf16_instances(
|
||||
// GNDHWC, 3d
|
||||
void add_device_column_to_image_gndhwc_3d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_ndhwc_3d_f16_instances(
|
||||
void add_device_column_to_image_gndhwc_3d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_ndhwc_3d_f32_instances(
|
||||
void add_device_column_to_image_gndhwc_3d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_ndhwc_3d_i8_instances(
|
||||
void add_device_column_to_image_gndhwc_3d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ColumnToImage>>>&
|
||||
instances);
|
||||
// NWGC/NHWGC/NDHWGC
|
||||
// Image to Column
|
||||
// NWGC, 1d
|
||||
void add_device_image_to_column_nwgc_1d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, BF16, BF16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nwgc_1d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F16, F16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nwgc_1d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F32, F32, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nwgc_1d_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, int8_t, int8_t, ImageToColumn>>>&
|
||||
instances);
|
||||
// NHWGC, 2d
|
||||
void add_device_image_to_column_nhwgc_2d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, BF16, BF16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nhwgc_2d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F16, F16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nhwgc_2d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F32, F32, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_nhwgc_2d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, int8_t, int8_t, ImageToColumn>>>&
|
||||
instances);
|
||||
// NDHWGC, 3d
|
||||
void add_device_image_to_column_ndhwgc_3d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, BF16, BF16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_ndhwgc_3d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F16, F16, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_ndhwgc_3d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F32, F32, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
void add_device_image_to_column_ndhwgc_3d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, int8_t, int8_t, ImageToColumn>>>&
|
||||
instances);
|
||||
|
||||
// Column to Image
|
||||
// NWGC, 1d
|
||||
void add_device_column_to_image_nwgc_1d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, BF16, BF16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nwgc_1d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F16, F16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nwgc_1d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, F32, F32, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nwgc_1d_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, NWGC, int8_t, int8_t, ColumnToImage>>>&
|
||||
instances);
|
||||
// NHWGC, 2d
|
||||
void add_device_column_to_image_nhwgc_2d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, BF16, BF16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nhwgc_2d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F16, F16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nhwgc_2d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, F32, F32, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_nhwgc_2d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<2, NHWGC, int8_t, int8_t, ColumnToImage>>>&
|
||||
instances);
|
||||
// NDHWGC, 3d
|
||||
void add_device_column_to_image_ndhwgc_3d_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, BF16, BF16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_ndhwgc_3d_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F16, F16, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_ndhwgc_3d_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, F32, F32, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
void add_device_column_to_image_ndhwgc_3d_i8_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceConvTensorRearrange<3, NDHWGC, int8_t, int8_t, ColumnToImage>>>&
|
||||
instances);
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename ImageLayout,
|
||||
@@ -151,60 +256,120 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_image_to_column_nwc_1d_f32_instances(op_ptrs);
|
||||
add_device_image_to_column_gnwc_1d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_image_to_column_nwc_1d_f16_instances(op_ptrs);
|
||||
add_device_image_to_column_gnwc_1d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_image_to_column_nwc_1d_bf16_instances(op_ptrs);
|
||||
add_device_image_to_column_gnwc_1d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_image_to_column_nwc_1d_i8_instances(op_ptrs);
|
||||
add_device_image_to_column_gnwc_1d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_image_to_column_nhwc_2d_f32_instances(op_ptrs);
|
||||
add_device_image_to_column_gnhwc_2d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_image_to_column_nhwc_2d_f16_instances(op_ptrs);
|
||||
add_device_image_to_column_gnhwc_2d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_image_to_column_nhwc_2d_bf16_instances(op_ptrs);
|
||||
add_device_image_to_column_gnhwc_2d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_image_to_column_nhwc_2d_i8_instances(op_ptrs);
|
||||
add_device_image_to_column_gnhwc_2d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_image_to_column_ndhwc_3d_f32_instances(op_ptrs);
|
||||
add_device_image_to_column_gndhwc_3d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_image_to_column_ndhwc_3d_f16_instances(op_ptrs);
|
||||
add_device_image_to_column_gndhwc_3d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_image_to_column_ndhwc_3d_bf16_instances(op_ptrs);
|
||||
add_device_image_to_column_gndhwc_3d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_image_to_column_ndhwc_3d_i8_instances(op_ptrs);
|
||||
add_device_image_to_column_gndhwc_3d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, NWGC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_image_to_column_nwgc_1d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_image_to_column_nwgc_1d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_image_to_column_nwgc_1d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_image_to_column_nwgc_1d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, NHWGC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_image_to_column_nhwgc_2d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_image_to_column_nhwgc_2d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_image_to_column_nhwgc_2d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_image_to_column_nhwgc_2d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, NDHWGC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_image_to_column_ndhwgc_3d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_image_to_column_ndhwgc_3d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_image_to_column_ndhwgc_3d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_image_to_column_ndhwgc_3d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -214,60 +379,120 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_column_to_image_nwc_1d_f32_instances(op_ptrs);
|
||||
add_device_column_to_image_gnwc_1d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_column_to_image_nwc_1d_f16_instances(op_ptrs);
|
||||
add_device_column_to_image_gnwc_1d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_column_to_image_nwc_1d_bf16_instances(op_ptrs);
|
||||
add_device_column_to_image_gnwc_1d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_column_to_image_nwc_1d_i8_instances(op_ptrs);
|
||||
add_device_column_to_image_gnwc_1d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_column_to_image_nhwc_2d_f32_instances(op_ptrs);
|
||||
add_device_column_to_image_gnhwc_2d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_column_to_image_nhwc_2d_f16_instances(op_ptrs);
|
||||
add_device_column_to_image_gnhwc_2d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_column_to_image_nhwc_2d_bf16_instances(op_ptrs);
|
||||
add_device_column_to_image_gnhwc_2d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_column_to_image_nhwc_2d_i8_instances(op_ptrs);
|
||||
add_device_column_to_image_gnhwc_2d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_column_to_image_ndhwc_3d_f32_instances(op_ptrs);
|
||||
add_device_column_to_image_gndhwc_3d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_column_to_image_ndhwc_3d_f16_instances(op_ptrs);
|
||||
add_device_column_to_image_gndhwc_3d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_column_to_image_ndhwc_3d_bf16_instances(op_ptrs);
|
||||
add_device_column_to_image_gndhwc_3d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_column_to_image_ndhwc_3d_i8_instances(op_ptrs);
|
||||
add_device_column_to_image_gndhwc_3d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, NWGC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_column_to_image_nwgc_1d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_column_to_image_nwgc_1d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_column_to_image_nwgc_1d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_column_to_image_nwgc_1d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, NHWGC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_column_to_image_nhwgc_2d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_column_to_image_nhwgc_2d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_column_to_image_nhwgc_2d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_column_to_image_nhwgc_2d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, NDHWGC>)
|
||||
{
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_column_to_image_ndhwgc_3d_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_column_to_image_ndhwgc_3d_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_column_to_image_ndhwgc_3d_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_column_to_image_ndhwgc_3d_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user