mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Added unit tests for TransformConvBwdWeightToGemm conv groups are merged.
This commit is contained in:
@@ -24,6 +24,7 @@ struct TestConfig
|
||||
static constexpr index_t NDim = NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpec = ConvolutionSpecialization::Default;
|
||||
static constexpr bool SplitN = false;
|
||||
static constexpr index_t NumberOfGroupsToMerge = NumGroupsToMerge;
|
||||
|
||||
using ADataType = float;
|
||||
using CDataType = float;
|
||||
@@ -43,6 +44,11 @@ using TestConfig1D_no_merge = TestConfig<1>;
|
||||
using TestConfig2D_no_merge = TestConfig<2>;
|
||||
using TestConfig3D_no_merge = TestConfig<3>;
|
||||
|
||||
constexpr index_t GroupsToMerge = 2;
|
||||
using TestConfig1D_merge = TestConfig<1, GroupsToMerge>;
|
||||
using TestConfig2D_merge = TestConfig<2, GroupsToMerge>;
|
||||
using TestConfig3D_merge = TestConfig<3, GroupsToMerge>;
|
||||
|
||||
// Test class template
|
||||
template <typename Config>
|
||||
class TestTransformConvBwdWeightToGemm : public ::testing::Test
|
||||
@@ -171,7 +177,10 @@ protected:
|
||||
using TestTypes = ::testing::Types<
|
||||
TestConfig1D_no_merge,
|
||||
TestConfig2D_no_merge,
|
||||
TestConfig3D_no_merge>;
|
||||
TestConfig3D_no_merge,
|
||||
TestConfig1D_merge,
|
||||
TestConfig2D_merge,
|
||||
TestConfig3D_merge>;
|
||||
|
||||
TYPED_TEST_SUITE(TestTransformConvBwdWeightToGemm, TestTypes);
|
||||
|
||||
@@ -246,7 +255,6 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, Constructor)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Test grid descriptors
|
||||
TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
{
|
||||
@@ -256,6 +264,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
constexpr auto I2 = number<2>{};
|
||||
constexpr auto I3 = number<3>{};
|
||||
|
||||
constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge;
|
||||
|
||||
if constexpr (NDim == 1)
|
||||
{
|
||||
typename TypeParam::TransformType transform(this->a_g_n_c_wis_lengths_1d_,
|
||||
@@ -272,7 +282,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
auto wei_grid_desc = transform.template make_wei_grid_desc<1>();
|
||||
|
||||
// Verify output grid descriptor dimensions
|
||||
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Wo_);
|
||||
|
||||
// Verify input grid descriptor dimensions
|
||||
@@ -281,8 +291,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
EXPECT_EQ(in_grid_desc.get_length(I2), this->C_);
|
||||
|
||||
// Verify weight grid descriptor dimensions
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_ * Gm);
|
||||
|
||||
}
|
||||
else if constexpr (NDim == 2)
|
||||
@@ -301,7 +311,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
auto wei_grid_desc = transform.template make_wei_grid_desc<2>();
|
||||
|
||||
// Verify output grid descriptor dimensions
|
||||
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_);
|
||||
|
||||
// Verify input grid descriptor dimensions
|
||||
@@ -311,8 +321,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
EXPECT_EQ(in_grid_desc.get_length(I3), this->C_);
|
||||
|
||||
// Verify weight grid descriptor dimensions
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_ * Gm);
|
||||
|
||||
}
|
||||
else if constexpr (NDim == 3)
|
||||
@@ -331,7 +341,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
auto wei_grid_desc = transform.template make_wei_grid_desc<3>();
|
||||
|
||||
// Verify output grid descriptor dimensions
|
||||
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
|
||||
|
||||
// Verify input grid descriptor dimensions
|
||||
@@ -342,14 +352,15 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
|
||||
EXPECT_EQ(in_grid_desc.get_length(number<4>{}), this->C_);
|
||||
|
||||
// Verify weight grid descriptor dimensions
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm);
|
||||
}
|
||||
}
|
||||
|
||||
// Test ABC grid descriptors
|
||||
TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
|
||||
{
|
||||
constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge;
|
||||
constexpr index_t NDim = TypeParam::NDim;
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
@@ -372,11 +383,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
|
||||
const auto& wei_desc = abc_descriptors[I2];
|
||||
|
||||
// Verify the descriptors are correctly created
|
||||
EXPECT_EQ(out_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(wei_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
|
||||
|
||||
// For input descriptor, verify the transformed dimensions
|
||||
EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_);
|
||||
EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_ * Gm);
|
||||
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Wo_);
|
||||
|
||||
}
|
||||
@@ -397,11 +408,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
|
||||
const auto& wei_desc = abc_descriptors[I2];
|
||||
|
||||
// Verify the descriptors are correctly created
|
||||
EXPECT_EQ(out_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(wei_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
|
||||
|
||||
// For input descriptor, verify the transformed dimensions
|
||||
EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_);
|
||||
EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_ * Gm);
|
||||
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_);
|
||||
|
||||
}
|
||||
@@ -422,11 +433,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
|
||||
const auto& wei_desc = abc_descriptors[I2];
|
||||
|
||||
// Verify the descriptors are correctly created
|
||||
EXPECT_EQ(out_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(wei_desc.get_length(I0), this->K_);
|
||||
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
|
||||
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
|
||||
|
||||
// For input descriptor, verify the transformed dimensions
|
||||
EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_);
|
||||
EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm);
|
||||
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user