diff --git a/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp b/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp index 73afe286d3..1229b4a6fb 100644 --- a/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp +++ b/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp @@ -24,6 +24,7 @@ struct TestConfig static constexpr index_t NDim = NDimSpatial; static constexpr ConvolutionSpecialization ConvSpec = ConvolutionSpecialization::Default; static constexpr bool SplitN = false; + static constexpr index_t NumberOfGroupsToMerge = NumGroupsToMerge; using ADataType = float; using CDataType = float; @@ -43,6 +44,11 @@ using TestConfig1D_no_merge = TestConfig<1>; using TestConfig2D_no_merge = TestConfig<2>; using TestConfig3D_no_merge = TestConfig<3>; +constexpr index_t GroupsToMerge = 2; +using TestConfig1D_merge = TestConfig<1, GroupsToMerge>; +using TestConfig2D_merge = TestConfig<2, GroupsToMerge>; +using TestConfig3D_merge = TestConfig<3, GroupsToMerge>; + // Test class template template class TestTransformConvBwdWeightToGemm : public ::testing::Test @@ -171,7 +177,10 @@ protected: using TestTypes = ::testing::Types< TestConfig1D_no_merge, TestConfig2D_no_merge, - TestConfig3D_no_merge>; + TestConfig3D_no_merge, + TestConfig1D_merge, + TestConfig2D_merge, + TestConfig3D_merge>; TYPED_TEST_SUITE(TestTransformConvBwdWeightToGemm, TestTypes); @@ -246,7 +255,6 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, Constructor) } } - // Test grid descriptors TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) { @@ -256,6 +264,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) constexpr auto I2 = number<2>{}; constexpr auto I3 = number<3>{}; + constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge; + if constexpr (NDim == 1) { typename TypeParam::TransformType transform(this->a_g_n_c_wis_lengths_1d_, @@ -272,7 +282,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) auto wei_grid_desc = transform.template make_wei_grid_desc<1>(); // Verify output grid descriptor dimensions - EXPECT_EQ(out_grid_desc.get_length(I0), this->K_); + EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm); EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Wo_); // Verify input grid descriptor dimensions @@ -281,8 +291,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) EXPECT_EQ(in_grid_desc.get_length(I2), this->C_); // Verify weight grid descriptor dimensions - EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_); - EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_); + EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_ * Gm); } else if constexpr (NDim == 2) @@ -301,7 +311,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) auto wei_grid_desc = transform.template make_wei_grid_desc<2>(); // Verify output grid descriptor dimensions - EXPECT_EQ(out_grid_desc.get_length(I0), this->K_); + EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm); EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_); // Verify input grid descriptor dimensions @@ -311,8 +321,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) EXPECT_EQ(in_grid_desc.get_length(I3), this->C_); // Verify weight grid descriptor dimensions - EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_); - EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_); + EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_ * Gm); } else if constexpr (NDim == 3) @@ -331,7 +341,7 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) auto wei_grid_desc = transform.template make_wei_grid_desc<3>(); // Verify output grid descriptor dimensions - EXPECT_EQ(out_grid_desc.get_length(I0), this->K_); + EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm); EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_); // Verify input grid descriptor dimensions @@ -342,14 +352,15 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) EXPECT_EQ(in_grid_desc.get_length(number<4>{}), this->C_); // Verify weight grid descriptor dimensions - EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_); - EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_); + EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm); } } // Test ABC grid descriptors TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors) { + constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge; constexpr index_t NDim = TypeParam::NDim; constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; @@ -372,11 +383,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors) const auto& wei_desc = abc_descriptors[I2]; // Verify the descriptors are correctly created - EXPECT_EQ(out_desc.get_length(I0), this->K_); - EXPECT_EQ(wei_desc.get_length(I0), this->K_); + EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm); // For input descriptor, verify the transformed dimensions - EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_); + EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_ * Gm); EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Wo_); } @@ -397,11 +408,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors) const auto& wei_desc = abc_descriptors[I2]; // Verify the descriptors are correctly created - EXPECT_EQ(out_desc.get_length(I0), this->K_); - EXPECT_EQ(wei_desc.get_length(I0), this->K_); + EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm); // For input descriptor, verify the transformed dimensions - EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_); + EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_ * Gm); EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_); } @@ -422,11 +433,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors) const auto& wei_desc = abc_descriptors[I2]; // Verify the descriptors are correctly created - EXPECT_EQ(out_desc.get_length(I0), this->K_); - EXPECT_EQ(wei_desc.get_length(I0), this->K_); + EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm); // For input descriptor, verify the transformed dimensions - EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_); + EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm); EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_); } }