Add support for groups in Img2Col/Col2Img (#1007)

* Add support for groups in Img2Col/Col2Img

* Fix interface test

* Fix interface test G to N

* Improve performance

* Change gemm layout to 3d

* Fixes

[ROCm/composable_kernel commit: 2e824c6d46]
This commit is contained in:
Bartłomiej Kocot
2023-10-31 10:46:32 +01:00
committed by GitHub
parent 1a389f2d2e
commit 60a0f176d3
30 changed files with 1114 additions and 281 deletions

View File

@@ -45,14 +45,20 @@ class TestConvTensorRearrange : public ::testing::Test
using namespace ck::tensor_layout::convolution;
using namespace ck::conv_tensor_rearrange_op;
using KernelTypes1d =
::testing::Types<std::tuple<GNWC, ImageToColumn>, std::tuple<GNWC, ColumnToImage>>;
using KernelTypes1d = ::testing::Types<std::tuple<GNWC, ImageToColumn>,
std::tuple<GNWC, ColumnToImage>,
std::tuple<NWGC, ImageToColumn>,
std::tuple<NWGC, ColumnToImage>>;
using KernelTypes2d =
::testing::Types<std::tuple<GNHWC, ImageToColumn>, std::tuple<GNHWC, ColumnToImage>>;
using KernelTypes2d = ::testing::Types<std::tuple<GNHWC, ImageToColumn>,
std::tuple<GNHWC, ColumnToImage>,
std::tuple<NHWGC, ImageToColumn>,
std::tuple<NHWGC, ColumnToImage>>;
using KernelTypes3d =
::testing::Types<std::tuple<GNDHWC, ImageToColumn>, std::tuple<GNDHWC, ColumnToImage>>;
using KernelTypes3d = ::testing::Types<std::tuple<GNDHWC, ImageToColumn>,
std::tuple<GNDHWC, ColumnToImage>,
std::tuple<NDHWGC, ImageToColumn>,
std::tuple<NDHWGC, ColumnToImage>>;
template <typename Tuple>
class TestConvTensorRearrange1d : public TestConvTensorRearrange<Tuple>
@@ -77,16 +83,16 @@ TYPED_TEST(TestConvTensorRearrange1d, Test1D)
{
this->conv_params.clear();
this->conv_params.push_back({1, 1, 4, 1, 192, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 64, 1, 64, {3}, {14}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {7}, {3}, {1}, {0}, {0}});
this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {3}, {1}, {1}, {0}, {0}});
this->conv_params.push_back({1, 2, 4, 1, 192, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 2, 64, 1, 64, {3}, {14}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 2, 64, 1, 64, {1}, {7}, {3}, {1}, {0}, {0}});
this->conv_params.push_back({1, 2, 64, 1, 64, {1}, {3}, {1}, {1}, {0}, {0}});
// ScalarPerVector should be 1
this->conv_params.push_back({1, 1, 4, 1, 1, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 2, 4, 1, 1, {3}, {28}, {1}, {1}, {1}, {1}});
// stride != 1
this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {2}, {1}, {1}, {1}});
this->conv_params.push_back({1, 2, 1, 1, 4, {3}, {28}, {2}, {1}, {1}, {1}});
// dilation != 1
this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {1}, {2}, {1}, {1}});
this->conv_params.push_back({1, 2, 1, 1, 4, {3}, {28}, {1}, {2}, {1}, {1}});
#ifdef CK_ENABLE_FP32
this->template Run<1, float, float>();
#endif
@@ -106,13 +112,13 @@ TYPED_TEST(TestConvTensorRearrange2d, Test2D)
this->conv_params.clear();
this->conv_params.push_back(
{2, 1, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
{2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 1, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
{2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 1, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
{2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
#ifdef CK_ENABLE_FP32
this->template Run<2, float, float>();
#endif
@@ -131,13 +137,13 @@ TYPED_TEST(TestConvTensorRearrange3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 1, 16, 1, 64, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {3, 3, 3}, {0, 0, 0}, {0, 0, 0}});
{3, 2, 16, 1, 64, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {3, 3, 3}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 2, 1, 64, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
{3, 2, 2, 1, 64, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 32, 1, 64, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
{3, 2, 32, 1, 64, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 64, 1, 64, {3, 3, 3}, {14, 14, 14}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}});
{3, 2, 64, 1, 64, {3, 3, 3}, {14, 14, 14}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}});
#ifdef CK_ENABLE_FP32
this->template Run<3, float, float>();
#endif

View File

@@ -53,7 +53,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
template <typename ConvTensorRearrangeOp>
bool Run()
{
const auto G = conv_param.G_;
const auto N = conv_param.N_;
const auto C = conv_param.C_;
const auto FakeC =
@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
const auto image_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
conv_param);
const auto gemm_desc = HostTensorDescriptor({NDoHoWo, CZYX});
const auto gemm_desc = HostTensorDescriptor({G, NDoHoWo, CZYX});
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{};
std::array<ck::index_t, 2> output_m_k_strides{};
std::array<ck::index_t, 3> output_g_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_param.output_spatial_lengths_, output_spatial_lengths);
copy(image_desc.GetStrides(), input_g_n_c_wis_strides);
copy(gemm_desc.GetStrides(), output_m_k_strides);
copy(gemm_desc.GetStrides(), output_g_m_k_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
@@ -100,13 +100,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
auto img2col = DeviceImgToColInstance{};
auto argument = img2col.MakeArgument(nullptr,
nullptr,
G,
N,
IsCPacked ? C : FakeC,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
output_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
@@ -119,13 +120,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
auto col2img = DeviceColToimgInstance{};
auto argument = col2img.MakeArgument(nullptr,
nullptr,
G,
N,
IsCPacked ? C : FakeC,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
output_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,