Added unit tests for TransformConvBwdWeightToGemm conv groups are merged.

This commit is contained in:
Ville Pietilä
2025-09-05 13:05:25 +00:00
parent 81a617c108
commit 1a2b0dcb44

View File

@@ -24,6 +24,7 @@ struct TestConfig
static constexpr index_t NDim = NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpec = ConvolutionSpecialization::Default;
static constexpr bool SplitN = false;
static constexpr index_t NumberOfGroupsToMerge = NumGroupsToMerge;
using ADataType = float;
using CDataType = float;
@@ -43,6 +44,11 @@ using TestConfig1D_no_merge = TestConfig<1>;
using TestConfig2D_no_merge = TestConfig<2>;
using TestConfig3D_no_merge = TestConfig<3>;
constexpr index_t GroupsToMerge = 2;
using TestConfig1D_merge = TestConfig<1, GroupsToMerge>;
using TestConfig2D_merge = TestConfig<2, GroupsToMerge>;
using TestConfig3D_merge = TestConfig<3, GroupsToMerge>;
// Test class template
template <typename Config>
class TestTransformConvBwdWeightToGemm : public ::testing::Test
@@ -171,7 +177,10 @@ protected:
using TestTypes = ::testing::Types<
TestConfig1D_no_merge,
TestConfig2D_no_merge,
TestConfig3D_no_merge>;
TestConfig3D_no_merge,
TestConfig1D_merge,
TestConfig2D_merge,
TestConfig3D_merge>;
TYPED_TEST_SUITE(TestTransformConvBwdWeightToGemm, TestTypes);
@@ -246,7 +255,6 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, Constructor)
}
}
// Test grid descriptors
TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
{
@@ -256,6 +264,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
constexpr auto I2 = number<2>{};
constexpr auto I3 = number<3>{};
constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge;
if constexpr (NDim == 1)
{
typename TypeParam::TransformType transform(this->a_g_n_c_wis_lengths_1d_,
@@ -272,7 +282,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
auto wei_grid_desc = transform.template make_wei_grid_desc<1>();
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Wo_);
// Verify input grid descriptor dimensions
@@ -281,8 +291,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
EXPECT_EQ(in_grid_desc.get_length(I2), this->C_);
// Verify weight grid descriptor dimensions
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_ * Gm);
}
else if constexpr (NDim == 2)
@@ -301,7 +311,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
auto wei_grid_desc = transform.template make_wei_grid_desc<2>();
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_);
// Verify input grid descriptor dimensions
@@ -311,8 +321,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
EXPECT_EQ(in_grid_desc.get_length(I3), this->C_);
// Verify weight grid descriptor dimensions
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_ * Gm);
}
else if constexpr (NDim == 3)
@@ -331,7 +341,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
auto wei_grid_desc = transform.template make_wei_grid_desc<3>();
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
// Verify input grid descriptor dimensions
@@ -342,14 +352,15 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
EXPECT_EQ(in_grid_desc.get_length(number<4>{}), this->C_);
// Verify weight grid descriptor dimensions
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm);
}
}
// Test ABC grid descriptors
TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
{
constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge;
constexpr index_t NDim = TypeParam::NDim;
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
@@ -372,11 +383,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
const auto& wei_desc = abc_descriptors[I2];
// Verify the descriptors are correctly created
EXPECT_EQ(out_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_desc.get_length(I0), this->K_);
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
// For input descriptor, verify the transformed dimensions
EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_);
EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_ * Gm);
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Wo_);
}
@@ -397,11 +408,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
const auto& wei_desc = abc_descriptors[I2];
// Verify the descriptors are correctly created
EXPECT_EQ(out_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_desc.get_length(I0), this->K_);
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
// For input descriptor, verify the transformed dimensions
EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_);
EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_ * Gm);
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_);
}
@@ -422,11 +433,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
const auto& wei_desc = abc_descriptors[I2];
// Verify the descriptors are correctly created
EXPECT_EQ(out_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_desc.get_length(I0), this->K_);
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
// For input descriptor, verify the transformed dimensions
EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_);
EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm);
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
}
}