WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases.

This commit is contained in:
Ville Pietilä
2025-09-26 13:38:24 +00:00
parent 8babf7195a
commit 558054eadb

View File

@@ -217,7 +217,9 @@ struct TransformConvBwdWeightToGemm
InRightPadD_{I0},
InRightPadH_{input_right_pads[I0]},
InRightPadW_{input_right_pads[I1]},
ZYX_{Y_ * X_}
ZYX_{Y_ * X_},
Kmerged_{K_},
Cmerged_{C_}
{
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
@@ -235,6 +237,13 @@ struct TransformConvBwdWeightToGemm
}
#endif
N_ = c_g_n_k_wos_lengths[I1];
// Group merging
if constexpr (NumGroupsToMerge > 1)
{
Cmerged_ = integer_divide_ceil(C_, NumGroupsToMerge) * NumGroupsToMerge;
Kmerged_ = integer_divide_ceil(K_, NumGroupsToMerge) * NumGroupsToMerge;
}
}
template <typename ConvDimsType,
@@ -490,22 +499,11 @@ struct TransformConvBwdWeightToGemm
{
// NHWGK
const index_t NDoHoWoStride = G_ * K_;
if constexpr (NumGroupsToMerge > 1)
{
const index_t KStride = G_;
constexpr auto GStride = I1;
return make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, N_ * Ho_ * Wo_),
make_tuple(GStride, KStride, NDoHoWoStride));
}
else
{
constexpr auto KStride = I1;
return make_naive_tensor_descriptor(
make_tuple(K_, N_ * Ho_ * Wo_),
constexpr auto KStride = I1;
return make_naive_tensor_descriptor(
make_tuple(Kmerged_, N_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -515,22 +513,11 @@ struct TransformConvBwdWeightToGemm
const index_t NStride = Hi_ * Wi_ * G_ * C_;
const index_t HiStride = Wi_ * G_ * C_;
const index_t WiStride = G_ * C_;
constexpr auto CStride = I1;
if constexpr (NumGroupsToMerge > 1)
{
const index_t CStride = G_;
constexpr auto GStride = I1;
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_, NumGroupsToMerge),
make_tuple(NStride, HiStride, WiStride, CStride, GStride));
}
else
{
constexpr auto CStride = I1;
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, Cmerged_),
make_tuple(NStride, HiStride, WiStride, CStride));
}
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
@@ -539,20 +526,10 @@ struct TransformConvBwdWeightToGemm
// GKYXC
const index_t KStride = Y_ * X_ * C_;
constexpr auto CStride = I1;
if constexpr (NumGroupsToMerge > 1)
{
const index_t GStride = K_ * Y_ * X_ * C_;
return make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Y_ * X_ * C_),
make_tuple(GStride, KStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(K_, Y_ * X_ * C_),
return make_naive_tensor_descriptor(
make_tuple(Kmerged_, Y_ * X_ * C_),
make_tuple(KStride, CStride));
}
}
//////////////////
@@ -742,104 +719,33 @@ struct TransformConvBwdWeightToGemm
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
// B: input tensor comes in K_N
if constexpr (NumGroupsToMerge > 1)
{
// Output tensor transformation
// [0, 1, 2] -> [0, 1]
// [Gm, K, (N*Ho*Wo)] -> [(Gm*K), (N*Ho*Wo)]
const auto out_gemm_m_gemm_k_grid_desc =
transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_pass_through_transform(N_ * Ho_ * Wo_)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
//[Gm, K, Y*X*C] -> [Gm*K, Y*X*C]
const auto wei_gemm_m_gemm_n_grid_desc = transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_pass_through_transform(Y_ * X_ * C_)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// Input tensor transformation, part 1.
// [N, Hi, Wi, C, Gm] -> [N, Hip, Wip, C, Gm]
const auto in_n_hip_wip_c_gm_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
// Input tensor transformation, part 2.
// [N, Hip, Wip, C, Gm] -> [N, (Y, Ho), (X, Wo), C, Gm]
const auto in_n_y_ho_x_wo_c_gm_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_gm_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(
make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(
make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}, sequence<6>{}));
// Input tensor transformation, part 3.
// [0, 1, 2, 3, 4, 5, 6] -> [0, 1]
// [N, Y, Ho, X, Wo, C, Gm] -> [(Y*X*C*Gm), (N*Ho*Wo)]
const auto in_gemm_n_gemm_k_grid_desc =
transform_tensor_descriptor(
in_n_y_ho_x_wo_c_gm_grid_desc,
make_tuple(
make_merge_transform(make_tuple(Y_, X_, C_, NumGroupsToMerge)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tuple(out_gemm_m_gemm_k_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_gemm_m_gemm_n_grid_desc);
}
else
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_pass_through_transform(Cmerged_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(Cmerged_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto in_gemmn_gemmktotal_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y_, X_, Cmerged_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
@@ -969,6 +875,8 @@ struct TransformConvBwdWeightToGemm
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
IndexType ZYX_;
IndexType Kmerged_;
IndexType Cmerged_;
};
} // namespace ck_tile