mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Add support for groups in Img2Col/Col2Img (#1007)
* Add support for groups in Img2Col/Col2Img
* Fix interface test
* Fix interface test G to N
* Improve performance
* Change gemm layout to 3d
* Fixes
[ROCm/composable_kernel commit: 2e824c6d46]
This commit is contained in:
@@ -45,14 +45,20 @@ class TestConvTensorRearrange : public ::testing::Test
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
using namespace ck::conv_tensor_rearrange_op;
|
||||
|
||||
using KernelTypes1d =
|
||||
::testing::Types<std::tuple<GNWC, ImageToColumn>, std::tuple<GNWC, ColumnToImage>>;
|
||||
using KernelTypes1d = ::testing::Types<std::tuple<GNWC, ImageToColumn>,
|
||||
std::tuple<GNWC, ColumnToImage>,
|
||||
std::tuple<NWGC, ImageToColumn>,
|
||||
std::tuple<NWGC, ColumnToImage>>;
|
||||
|
||||
using KernelTypes2d =
|
||||
::testing::Types<std::tuple<GNHWC, ImageToColumn>, std::tuple<GNHWC, ColumnToImage>>;
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<GNHWC, ImageToColumn>,
|
||||
std::tuple<GNHWC, ColumnToImage>,
|
||||
std::tuple<NHWGC, ImageToColumn>,
|
||||
std::tuple<NHWGC, ColumnToImage>>;
|
||||
|
||||
using KernelTypes3d =
|
||||
::testing::Types<std::tuple<GNDHWC, ImageToColumn>, std::tuple<GNDHWC, ColumnToImage>>;
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<GNDHWC, ImageToColumn>,
|
||||
std::tuple<GNDHWC, ColumnToImage>,
|
||||
std::tuple<NDHWGC, ImageToColumn>,
|
||||
std::tuple<NDHWGC, ColumnToImage>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestConvTensorRearrange1d : public TestConvTensorRearrange<Tuple>
|
||||
@@ -77,16 +83,16 @@ TYPED_TEST(TestConvTensorRearrange1d, Test1D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
|
||||
this->conv_params.push_back({1, 1, 4, 1, 192, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 1, 64, 1, 64, {3}, {14}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {7}, {3}, {1}, {0}, {0}});
|
||||
this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {3}, {1}, {1}, {0}, {0}});
|
||||
this->conv_params.push_back({1, 2, 4, 1, 192, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 64, 1, 64, {3}, {14}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 64, 1, 64, {1}, {7}, {3}, {1}, {0}, {0}});
|
||||
this->conv_params.push_back({1, 2, 64, 1, 64, {1}, {3}, {1}, {1}, {0}, {0}});
|
||||
// ScalarPerVector should be 1
|
||||
this->conv_params.push_back({1, 1, 4, 1, 1, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 4, 1, 1, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
// stride != 1
|
||||
this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {2}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 1, 1, 4, {3}, {28}, {2}, {1}, {1}, {1}});
|
||||
// dilation != 1
|
||||
this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {1}, {2}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 1, 1, 4, {3}, {28}, {1}, {2}, {1}, {1}});
|
||||
#ifdef CK_ENABLE_FP32
|
||||
this->template Run<1, float, float>();
|
||||
#endif
|
||||
@@ -106,13 +112,13 @@ TYPED_TEST(TestConvTensorRearrange2d, Test2D)
|
||||
this->conv_params.clear();
|
||||
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
{2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
{2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
|
||||
{2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
|
||||
#ifdef CK_ENABLE_FP32
|
||||
this->template Run<2, float, float>();
|
||||
#endif
|
||||
@@ -131,13 +137,13 @@ TYPED_TEST(TestConvTensorRearrange3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 16, 1, 64, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {3, 3, 3}, {0, 0, 0}, {0, 0, 0}});
|
||||
{3, 2, 16, 1, 64, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {3, 3, 3}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 2, 1, 64, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
{3, 2, 2, 1, 64, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 32, 1, 64, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
{3, 2, 32, 1, 64, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 64, 1, 64, {3, 3, 3}, {14, 14, 14}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}});
|
||||
{3, 2, 64, 1, 64, {3, 3, 3}, {14, 14, 14}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}});
|
||||
#ifdef CK_ENABLE_FP32
|
||||
this->template Run<3, float, float>();
|
||||
#endif
|
||||
|
||||
@@ -53,7 +53,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
|
||||
template <typename ConvTensorRearrangeOp>
|
||||
bool Run()
|
||||
{
|
||||
|
||||
const auto G = conv_param.G_;
|
||||
const auto N = conv_param.N_;
|
||||
const auto C = conv_param.C_;
|
||||
const auto FakeC =
|
||||
@@ -71,13 +71,13 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
|
||||
const auto image_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
|
||||
conv_param);
|
||||
const auto gemm_desc = HostTensorDescriptor({NDoHoWo, CZYX});
|
||||
const auto gemm_desc = HostTensorDescriptor({G, NDoHoWo, CZYX});
|
||||
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, 2> output_m_k_strides{};
|
||||
std::array<ck::index_t, 3> output_g_m_k_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
@@ -89,7 +89,7 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
|
||||
copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths);
|
||||
copy(conv_param.output_spatial_lengths_, output_spatial_lengths);
|
||||
copy(image_desc.GetStrides(), input_g_n_c_wis_strides);
|
||||
copy(gemm_desc.GetStrides(), output_m_k_strides);
|
||||
copy(gemm_desc.GetStrides(), output_g_m_k_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
@@ -100,13 +100,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
|
||||
auto img2col = DeviceImgToColInstance{};
|
||||
auto argument = img2col.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
G,
|
||||
N,
|
||||
IsCPacked ? C : FakeC,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_g_n_c_wis_strides,
|
||||
output_m_k_strides,
|
||||
output_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -119,13 +120,14 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
|
||||
auto col2img = DeviceColToimgInstance{};
|
||||
auto argument = col2img.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
G,
|
||||
N,
|
||||
IsCPacked ? C : FakeC,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_g_n_c_wis_strides,
|
||||
output_m_k_strides,
|
||||
output_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
Reference in New Issue
Block a user