From 8845b2325408d136c34fddcd42e1c5fa16f69260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Mon, 8 Sep 2025 15:41:54 +0000 Subject: [PATCH] WIP: Tensor transformations. --- .../transform_conv_bwd_weight_to_gemm.hpp | 251 ++++++++++++++---- ...test_transform_conv_bwd_weight_to_gemm.cpp | 142 ++++++---- 2 files changed, 295 insertions(+), 98 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp index bb1cbcd46e..cbe3d8a54e 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -410,17 +410,29 @@ struct TransformConvBwdWeightToGemm } #endif + ////////////////// + // 1D + ////////////////// template ::type = false> CK_TILE_HOST auto make_out_grid_desc() const { // NWGK const index_t NDoHoWoStride = G_ * K_; - constexpr auto KStride = I1; + const index_t GStride = K_; + constexpr auto KStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - - return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_), - make_tuple(KStride, NDoHoWoStride)); + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, N_ * Wo_), + make_tuple(GStride, KStride, NDoHoWoStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(K_, N_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } } template ::type = false> @@ -429,11 +441,21 @@ struct TransformConvBwdWeightToGemm // NWGC const index_t NStride = Wi_ * G_ * C_; const index_t WiStride = G_ * C_; + const index_t GStride = C_; constexpr auto CStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), - make_tuple(NStride, WiStride, CStride)); + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStride, WiStride, GStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride)); + } } template ::type = false> @@ -441,23 +463,44 @@ struct TransformConvBwdWeightToGemm { // GKXC const index_t KStride = X_ * C_; + const index_t GStride = K_ * X_ * C_; constexpr auto CXStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride)); + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K_, X_ * C_), + make_tuple(GStride, KStride, CXStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), + make_tuple(KStride, CXStride)); + } } + ////////////////// + // 2D + ////////////////// template ::type = false> CK_TILE_HOST auto make_out_grid_desc() const { // NHWGK const index_t NDoHoWoStride = G_ * K_; - constexpr auto KStride = I1; + const index_t GStride = K_; + constexpr auto KStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - - return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_), - make_tuple(KStride, NDoHoWoStride)); + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, N_ * Ho_ * Wo_), + make_tuple(KStride, GStride, NDoHoWoStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } } template ::type = false> @@ -467,11 +510,21 @@ struct TransformConvBwdWeightToGemm const index_t NStride = Hi_ * Wi_ * G_ * C_; const index_t HiStride = Wi_ * G_ * C_; const index_t WiStride = G_ * C_; + const index_t GStride = C_; constexpr auto CStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), - make_tuple(NStride, HiStride, WiStride, CStride)); + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStride, HiStride, WiStride, GStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride)); + } } template ::type = false> @@ -479,24 +532,44 @@ struct TransformConvBwdWeightToGemm { // GKYXC const index_t KStride = Y_ * X_ * C_; + const index_t GStride = K_ * Y_ * X_ * C_; constexpr auto CStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_), - make_tuple(KStride, CStride)); + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K_, Y_ * X_ * C_), + make_tuple(GStride, KStride, CStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_), + make_tuple(KStride, CStride)); + } } + ////////////////// + // 3D + ////////////////// template ::type = false> CK_TILE_HOST auto make_out_grid_desc() const { // NDHWGK const index_t NDoHoWoStride = G_ * K_; - constexpr auto KStride = I1; - - // TODO Add support for NumGroupsToMerge > 1 - - return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_), - make_tuple(KStride, NDoHoWoStride)); + const index_t GStride = K_; + constexpr auto KStride = I1; + + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, N_ * Do_ * Ho_ * Wo_), + make_tuple(KStride, GStride, NDoHoWoStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } } template ::type = false> @@ -506,12 +579,21 @@ struct TransformConvBwdWeightToGemm const index_t DiStride = Hi_ * Wi_ * G_ * C_; const index_t HiStride = Wi_ * G_ * C_; const index_t WiStride = G_ * C_; + const index_t GStride = C_; constexpr auto CStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor( + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( make_tuple(N_, Di_, Hi_, Wi_, C_), make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } } template ::type = false> @@ -519,11 +601,20 @@ struct TransformConvBwdWeightToGemm { // KZYXC const index_t KStride = Z_ * Y_ * X_ * C_; + const index_t GStride = K_ * Z_ * Y_ * X_ * C_; constexpr auto CStride = I1; - // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_), + if constexpr (NumGroupsToMerge > 1) + { + return make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_ * C_), + make_tuple(GStride, KStride, CStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_), make_tuple(KStride, CStride)); + } } // TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as @@ -537,31 +628,85 @@ struct TransformConvBwdWeightToGemm const auto wei_grid_desc = make_wei_grid_desc(); // B: input tensor comes in K_N - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + if constexpr (NumGroupsToMerge > 1) + { + // [K, Gm, (N*Wo)] -> [ (X*C), (N*Wo*Gm)] + const auto out_gemm_m_gem_k_total = + transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(X_ * C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C] + const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + // [N, Wip, Gm, C] -> [N, X, Wo, Gm, C] + const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor( + in_n_wip_gm_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform( + make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{})); - const auto in_gemmn_gemmktotal_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(X_, C_)), - make_merge_transform(make_tuple(N_, Wo_))), - make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + // [N, X, Wo, Gm, C] -> [ (X*C), (N*Wo*Gm) ] + const auto in_gemmn_gemmktotal_grid_desc = + transform_tensor_descriptor( + in_n_x_wo_gm_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(X_, C_)), + make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge))), + make_tuple(sequence<1, 4>{}, sequence<0, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + return make_tuple(out_gemm_m_gem_k_total, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } + else + { + // [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C] + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + // [N, Wip, C] -> [N, X, Wo, C] + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + // [N, X, Wo, C] -> [ (X*C), (N*Wo) ] + const auto in_gemmn_gemmktotal_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X_, C_)), + make_merge_transform(make_tuple(N_, Wo_))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } } template ::type = false> diff --git a/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp b/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp index 1229b4a6fb..4ce6c4ea9b 100644 --- a/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp +++ b/test/ck_tile/transform_conv_to_gemm/test_transform_conv_bwd_weight_to_gemm.cpp @@ -263,6 +263,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) constexpr auto I1 = number<1>{}; constexpr auto I2 = number<2>{}; constexpr auto I3 = number<3>{}; + constexpr auto I4 = number<4>{}; + constexpr auto I5 = number<5>{}; constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge; @@ -281,18 +283,34 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) auto in_grid_desc = transform.template make_in_grid_desc<1>(); auto wei_grid_desc = transform.template make_wei_grid_desc<1>(); - // Verify output grid descriptor dimensions - EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm); - EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Wo_); - // Verify input grid descriptor dimensions EXPECT_EQ(in_grid_desc.get_length(I0), this->N_); EXPECT_EQ(in_grid_desc.get_length(I1), this->Wi_); - EXPECT_EQ(in_grid_desc.get_length(I2), this->C_); - // Verify weight grid descriptor dimensions - EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm); - EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_ * Gm); + // Verify output grid descriptor dimensions + EXPECT_EQ(out_grid_desc.get_length(I0), this->K_); + + if constexpr (Gm > 1) + { + EXPECT_EQ(in_grid_desc.get_length(I2), Gm); + EXPECT_EQ(in_grid_desc.get_length(I3), this->C_); + + EXPECT_EQ(wei_grid_desc.get_length(I0), Gm); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->K_); + EXPECT_EQ(wei_grid_desc.get_length(I2), this->X_ * this->C_); + + EXPECT_EQ(out_grid_desc.get_length(I1), Gm); + EXPECT_EQ(out_grid_desc.get_length(I2), this->N_ * this->Wo_); + } + else + { + EXPECT_EQ(in_grid_desc.get_length(I2), this->C_); + + EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_); + + EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Wo_); + } } else if constexpr (NDim == 2) @@ -310,19 +328,35 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) auto in_grid_desc = transform.template make_in_grid_desc<2>(); auto wei_grid_desc = transform.template make_wei_grid_desc<2>(); - // Verify output grid descriptor dimensions - EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm); - EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_); - // Verify input grid descriptor dimensions EXPECT_EQ(in_grid_desc.get_length(I0), this->N_); EXPECT_EQ(in_grid_desc.get_length(I1), this->Hi_); EXPECT_EQ(in_grid_desc.get_length(I2), this->Wi_); - EXPECT_EQ(in_grid_desc.get_length(I3), this->C_); - - // Verify weight grid descriptor dimensions - EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm); - EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_ * Gm); + + // Verify output grid descriptor dimensions + EXPECT_EQ(out_grid_desc.get_length(I0), this->K_); + + if constexpr (Gm > 1) + { + EXPECT_EQ(in_grid_desc.get_length(I3), Gm); + EXPECT_EQ(in_grid_desc.get_length(I4), this->C_); + + EXPECT_EQ(wei_grid_desc.get_length(I0), Gm); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->K_); + EXPECT_EQ(wei_grid_desc.get_length(I2), this->Y_ * this->X_ * this->C_); + + EXPECT_EQ(out_grid_desc.get_length(I1), Gm); + EXPECT_EQ(out_grid_desc.get_length(I2), this->N_ * this->Ho_ * this->Wo_); + } + else + { + EXPECT_EQ(in_grid_desc.get_length(I3), this->C_); + + EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_); + + EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_); + } } else if constexpr (NDim == 3) @@ -340,20 +374,36 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors) auto in_grid_desc = transform.template make_in_grid_desc<3>(); auto wei_grid_desc = transform.template make_wei_grid_desc<3>(); - // Verify output grid descriptor dimensions - EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm); - EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_); - // Verify input grid descriptor dimensions EXPECT_EQ(in_grid_desc.get_length(I0), this->N_); EXPECT_EQ(in_grid_desc.get_length(I1), this->Di_); EXPECT_EQ(in_grid_desc.get_length(I2), this->Hi_); EXPECT_EQ(in_grid_desc.get_length(I3), this->Wi_); - EXPECT_EQ(in_grid_desc.get_length(number<4>{}), this->C_); - // Verify weight grid descriptor dimensions - EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm); - EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm); + // Verify output grid descriptor dimensions + EXPECT_EQ(out_grid_desc.get_length(I0), this->K_); + + if constexpr (Gm > 1) + { + EXPECT_EQ(in_grid_desc.get_length(I4), Gm); + EXPECT_EQ(in_grid_desc.get_length(I5), this->C_); + + EXPECT_EQ(wei_grid_desc.get_length(I0), Gm); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->K_); + EXPECT_EQ(wei_grid_desc.get_length(I2), this->Z_ * this->Y_ * this->X_ * this->C_); + + EXPECT_EQ(out_grid_desc.get_length(I1), Gm); + EXPECT_EQ(out_grid_desc.get_length(I2), this->N_ * this->Do_ * this->Ho_ * this->Wo_); + } + else + { + EXPECT_EQ(in_grid_desc.get_length(I4), this->C_); + + EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_); + EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_); + + EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_); + } } } @@ -378,18 +428,24 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors) // Test combined ABC grid descriptors const auto abc_descriptors = transform.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(); - const auto& out_desc = abc_descriptors[I0]; - const auto& in_desc = abc_descriptors[I1]; - const auto& wei_desc = abc_descriptors[I2]; + const auto& out_desc = abc_descriptors[I0]; // GEMM A () + const auto& in_desc = abc_descriptors[I1]; // GEMM B (N_gemm, K_gemm) + const auto& wei_desc = abc_descriptors[I2]; // GEMM C + + EXPECT_EQ(out_desc.get_num_of_dimension(), 3); + EXPECT_EQ(in_desc.get_num_of_dimension(), 2); + EXPECT_EQ(wei_desc.get_num_of_dimension(), 3); - // Verify the descriptors are correctly created - EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm); - EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm); - - // For input descriptor, verify the transformed dimensions - EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_ * Gm); - EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Wo_); + EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_); + EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Wo_ * Gm); + + // // Verify GEMM M-dimension that should depend on Gm + // EXPECT_EQ(out_desc.get_length(I1), this->K_ * Gm); + // EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm); + // // Verify GEMM N-dimension that should depend on Gm. + // EXPECT_EQ(in_desc.get_length(I1), this->X_ * this->C_ * Gm); + // EXPECT_EQ(wei_desc.get_length(I1), this->X_ * this->C_ * Gm); } else if constexpr (NDim == 2) { @@ -407,13 +463,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors) const auto& in_desc = abc_descriptors[I1]; const auto& wei_desc = abc_descriptors[I2]; - // Verify the descriptors are correctly created - EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(out_desc.get_length(I1), this->K_ * Gm); EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm); - // For input descriptor, verify the transformed dimensions - EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_ * Gm); - EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_); + EXPECT_EQ(in_desc.get_length(I1), this->Y_ * this->X_ * this->C_ * Gm); + EXPECT_EQ(wei_desc.get_length(I1), this->Y_ *this->X_ * this->C_ * Gm); } else if constexpr (NDim == 3) @@ -432,13 +486,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors) const auto& in_desc = abc_descriptors[I1]; const auto& wei_desc = abc_descriptors[I2]; - // Verify the descriptors are correctly created - EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm); + EXPECT_EQ(out_desc.get_length(I1), this->K_ * Gm); EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm); - // For input descriptor, verify the transformed dimensions - EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm); - EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_); + EXPECT_EQ(in_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm); + EXPECT_EQ(wei_desc.get_length(I1), this->Z_ * this->Y_ *this->X_ * this->C_ * Gm); } }