WIP: Tensor transformations.

This commit is contained in:
Ville Pietilä
2025-09-08 15:41:54 +00:00
parent 1a2b0dcb44
commit 8845b23254
2 changed files with 295 additions and 98 deletions

View File

@@ -410,17 +410,29 @@ struct TransformConvBwdWeightToGemm
}
#endif
//////////////////
// 1D
//////////////////
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
CK_TILE_HOST auto make_out_grid_desc() const
{
// NWGK
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
const index_t GStride = K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, N_ * Wo_),
make_tuple(GStride, KStride, NDoHoWoStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -429,11 +441,21 @@ struct TransformConvBwdWeightToGemm
// NWGC
const index_t NStride = Wi_ * G_ * C_;
const index_t WiStride = G_ * C_;
const index_t GStride = C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride));
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, WiStride, GStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Wi_, C_),
make_tuple(NStride, WiStride, CStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -441,23 +463,44 @@ struct TransformConvBwdWeightToGemm
{
// GKXC
const index_t KStride = X_ * C_;
const index_t GStride = K_ * X_ * C_;
constexpr auto CXStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride));
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, X_ * C_),
make_tuple(GStride, KStride, CXStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_),
make_tuple(KStride, CXStride));
}
}
//////////////////
// 2D
//////////////////
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
CK_TILE_HOST auto make_out_grid_desc() const
{
// NHWGK
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
const index_t GStride = K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, N_ * Ho_ * Wo_),
make_tuple(KStride, GStride, NDoHoWoStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -467,11 +510,21 @@ struct TransformConvBwdWeightToGemm
const index_t NStride = Hi_ * Wi_ * G_ * C_;
const index_t HiStride = Wi_ * G_ * C_;
const index_t WiStride = G_ * C_;
const index_t GStride = C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStride, HiStride, WiStride, CStride));
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, HiStride, WiStride, GStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStride, HiStride, WiStride, CStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -479,24 +532,44 @@ struct TransformConvBwdWeightToGemm
{
// GKYXC
const index_t KStride = Y_ * X_ * C_;
const index_t GStride = K_ * Y_ * X_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
make_tuple(KStride, CStride));
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Y_ * X_ * C_),
make_tuple(GStride, KStride, CStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_),
make_tuple(KStride, CStride));
}
}
//////////////////
// 3D
//////////////////
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
CK_TILE_HOST auto make_out_grid_desc() const
{
// NDHWGK
const index_t NDoHoWoStride = G_ * K_;
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
const index_t GStride = K_;
constexpr auto KStride = I1;
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStride, GStride, NDoHoWoStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -506,12 +579,21 @@ struct TransformConvBwdWeightToGemm
const index_t DiStride = Hi_ * Wi_ * G_ * C_;
const index_t HiStride = Wi_ * G_ * C_;
const index_t WiStride = G_ * C_;
const index_t GStride = C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -519,11 +601,20 @@ struct TransformConvBwdWeightToGemm
{
// KZYXC
const index_t KStride = Z_ * Y_ * X_ * C_;
const index_t GStride = K_ * Z_ * Y_ * X_ * C_;
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
if constexpr (NumGroupsToMerge > 1)
{
return make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_ * C_),
make_tuple(GStride, KStride, CStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_),
make_tuple(KStride, CStride));
}
}
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
@@ -537,31 +628,85 @@ struct TransformConvBwdWeightToGemm
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// B: input tensor comes in K_N
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
if constexpr (NumGroupsToMerge > 1)
{
// [K, Gm, (N*Wo)] -> [ (X*C), (N*Wo*Gm)]
const auto out_gemm_m_gem_k_total =
transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_pass_through_transform(X_ * C_)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
// [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
in_n_wip_gm_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(
make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// [N, X, Wo, Gm, C] -> [ (X*C), (N*Wo*Gm) ]
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(
in_n_x_wo_gm_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge))),
make_tuple(sequence<1, 4>{}, sequence<0, 2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
return make_tuple(out_gemm_m_gem_k_total, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
else
{
// [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
// [N, Wip, C] -> [N, X, Wo, C]
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
// [N, X, Wo, C] -> [ (X*C), (N*Wo) ]
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>

View File

@@ -263,6 +263,8 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
constexpr auto I1 = number<1>{};
constexpr auto I2 = number<2>{};
constexpr auto I3 = number<3>{};
constexpr auto I4 = number<4>{};
constexpr auto I5 = number<5>{};
constexpr index_t Gm = TypeParam::NumberOfGroupsToMerge;
@@ -281,18 +283,34 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
auto in_grid_desc = transform.template make_in_grid_desc<1>();
auto wei_grid_desc = transform.template make_wei_grid_desc<1>();
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Wo_);
// Verify input grid descriptor dimensions
EXPECT_EQ(in_grid_desc.get_length(I0), this->N_);
EXPECT_EQ(in_grid_desc.get_length(I1), this->Wi_);
EXPECT_EQ(in_grid_desc.get_length(I2), this->C_);
// Verify weight grid descriptor dimensions
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_ * Gm);
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
if constexpr (Gm > 1)
{
EXPECT_EQ(in_grid_desc.get_length(I2), Gm);
EXPECT_EQ(in_grid_desc.get_length(I3), this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I2), this->X_ * this->C_);
EXPECT_EQ(out_grid_desc.get_length(I1), Gm);
EXPECT_EQ(out_grid_desc.get_length(I2), this->N_ * this->Wo_);
}
else
{
EXPECT_EQ(in_grid_desc.get_length(I2), this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->X_ * this->C_);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Wo_);
}
}
else if constexpr (NDim == 2)
@@ -310,19 +328,35 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
auto in_grid_desc = transform.template make_in_grid_desc<2>();
auto wei_grid_desc = transform.template make_wei_grid_desc<2>();
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_);
// Verify input grid descriptor dimensions
EXPECT_EQ(in_grid_desc.get_length(I0), this->N_);
EXPECT_EQ(in_grid_desc.get_length(I1), this->Hi_);
EXPECT_EQ(in_grid_desc.get_length(I2), this->Wi_);
EXPECT_EQ(in_grid_desc.get_length(I3), this->C_);
// Verify weight grid descriptor dimensions
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_ * Gm);
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
if constexpr (Gm > 1)
{
EXPECT_EQ(in_grid_desc.get_length(I3), Gm);
EXPECT_EQ(in_grid_desc.get_length(I4), this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I2), this->Y_ * this->X_ * this->C_);
EXPECT_EQ(out_grid_desc.get_length(I1), Gm);
EXPECT_EQ(out_grid_desc.get_length(I2), this->N_ * this->Ho_ * this->Wo_);
}
else
{
EXPECT_EQ(in_grid_desc.get_length(I3), this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Y_ * this->X_ * this->C_);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_);
}
}
else if constexpr (NDim == 3)
@@ -340,20 +374,36 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, GridDescriptors)
auto in_grid_desc = transform.template make_in_grid_desc<3>();
auto wei_grid_desc = transform.template make_wei_grid_desc<3>();
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
// Verify input grid descriptor dimensions
EXPECT_EQ(in_grid_desc.get_length(I0), this->N_);
EXPECT_EQ(in_grid_desc.get_length(I1), this->Di_);
EXPECT_EQ(in_grid_desc.get_length(I2), this->Hi_);
EXPECT_EQ(in_grid_desc.get_length(I3), this->Wi_);
EXPECT_EQ(in_grid_desc.get_length(number<4>{}), this->C_);
// Verify weight grid descriptor dimensions
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm);
// Verify output grid descriptor dimensions
EXPECT_EQ(out_grid_desc.get_length(I0), this->K_);
if constexpr (Gm > 1)
{
EXPECT_EQ(in_grid_desc.get_length(I4), Gm);
EXPECT_EQ(in_grid_desc.get_length(I5), this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), Gm);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I2), this->Z_ * this->Y_ * this->X_ * this->C_);
EXPECT_EQ(out_grid_desc.get_length(I1), Gm);
EXPECT_EQ(out_grid_desc.get_length(I2), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
}
else
{
EXPECT_EQ(in_grid_desc.get_length(I4), this->C_);
EXPECT_EQ(wei_grid_desc.get_length(I0), this->K_);
EXPECT_EQ(wei_grid_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_);
EXPECT_EQ(out_grid_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
}
}
}
@@ -378,18 +428,24 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
// Test combined ABC grid descriptors
const auto abc_descriptors = transform.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>();
const auto& out_desc = abc_descriptors[I0];
const auto& in_desc = abc_descriptors[I1];
const auto& wei_desc = abc_descriptors[I2];
const auto& out_desc = abc_descriptors[I0]; // GEMM A ()
const auto& in_desc = abc_descriptors[I1]; // GEMM B (N_gemm, K_gemm)
const auto& wei_desc = abc_descriptors[I2]; // GEMM C
EXPECT_EQ(out_desc.get_num_of_dimension(), 3);
EXPECT_EQ(in_desc.get_num_of_dimension(), 2);
EXPECT_EQ(wei_desc.get_num_of_dimension(), 3);
// Verify the descriptors are correctly created
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
// For input descriptor, verify the transformed dimensions
EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_ * Gm);
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Wo_);
EXPECT_EQ(in_desc.get_length(I0), this->X_ * this->C_);
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Wo_ * Gm);
// // Verify GEMM M-dimension that should depend on Gm
// EXPECT_EQ(out_desc.get_length(I1), this->K_ * Gm);
// EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
// // Verify GEMM N-dimension that should depend on Gm.
// EXPECT_EQ(in_desc.get_length(I1), this->X_ * this->C_ * Gm);
// EXPECT_EQ(wei_desc.get_length(I1), this->X_ * this->C_ * Gm);
}
else if constexpr (NDim == 2)
{
@@ -407,13 +463,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
const auto& in_desc = abc_descriptors[I1];
const auto& wei_desc = abc_descriptors[I2];
// Verify the descriptors are correctly created
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_desc.get_length(I1), this->K_ * Gm);
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
// For input descriptor, verify the transformed dimensions
EXPECT_EQ(in_desc.get_length(I0), this->Y_ * this->X_ * this->C_ * Gm);
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Ho_ * this->Wo_);
EXPECT_EQ(in_desc.get_length(I1), this->Y_ * this->X_ * this->C_ * Gm);
EXPECT_EQ(wei_desc.get_length(I1), this->Y_ *this->X_ * this->C_ * Gm);
}
else if constexpr (NDim == 3)
@@ -432,13 +486,11 @@ TYPED_TEST(TestTransformConvBwdWeightToGemm, ABCGridDescriptors)
const auto& in_desc = abc_descriptors[I1];
const auto& wei_desc = abc_descriptors[I2];
// Verify the descriptors are correctly created
EXPECT_EQ(out_desc.get_length(I0), this->K_ * Gm);
EXPECT_EQ(out_desc.get_length(I1), this->K_ * Gm);
EXPECT_EQ(wei_desc.get_length(I0), this->K_ * Gm);
// For input descriptor, verify the transformed dimensions
EXPECT_EQ(in_desc.get_length(I0), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm);
EXPECT_EQ(in_desc.get_length(I1), this->N_ * this->Do_ * this->Ho_ * this->Wo_);
EXPECT_EQ(in_desc.get_length(I1), this->Z_ * this->Y_ * this->X_ * this->C_ * Gm);
EXPECT_EQ(wei_desc.get_length(I1), this->Z_ * this->Y_ *this->X_ * this->C_ * Gm);
}
}