mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases.
This commit is contained in:
@@ -217,7 +217,9 @@ struct TransformConvBwdWeightToGemm
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{input_right_pads[I0]},
|
||||
InRightPadW_{input_right_pads[I1]},
|
||||
ZYX_{Y_ * X_}
|
||||
ZYX_{Y_ * X_},
|
||||
Kmerged_{K_},
|
||||
Cmerged_{C_}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -235,6 +237,13 @@ struct TransformConvBwdWeightToGemm
|
||||
}
|
||||
#endif
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
|
||||
// Group merging
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
Cmerged_ = integer_divide_ceil(C_, NumGroupsToMerge) * NumGroupsToMerge;
|
||||
Kmerged_ = integer_divide_ceil(K_, NumGroupsToMerge) * NumGroupsToMerge;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ConvDimsType,
|
||||
@@ -490,22 +499,11 @@ struct TransformConvBwdWeightToGemm
|
||||
{
|
||||
// NHWGK
|
||||
const index_t NDoHoWoStride = G_ * K_;
|
||||
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t KStride = G_;
|
||||
constexpr auto GStride = I1;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, N_ * Ho_ * Wo_),
|
||||
make_tuple(GStride, KStride, NDoHoWoStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto KStride = I1;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, N_ * Ho_ * Wo_),
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Kmerged_, N_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -515,22 +513,11 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t NStride = Hi_ * Wi_ * G_ * C_;
|
||||
const index_t HiStride = Wi_ * G_ * C_;
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t CStride = G_;
|
||||
constexpr auto GStride = I1;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, C_, NumGroupsToMerge),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride, GStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto CStride = I1;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, C_),
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, Cmerged_),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -539,20 +526,10 @@ struct TransformConvBwdWeightToGemm
|
||||
// GKYXC
|
||||
const index_t KStride = Y_ * X_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t GStride = K_ * Y_ * X_ * C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, Y_ * X_ * C_),
|
||||
make_tuple(GStride, KStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Y_ * X_ * C_),
|
||||
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Kmerged_, Y_ * X_ * C_),
|
||||
make_tuple(KStride, CStride));
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////
|
||||
@@ -742,104 +719,33 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto in_grid_desc = make_in_grid_desc<NDimSpatial>();
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [Gm, K, (N*Ho*Wo)] -> [(K*Gm), (N*Ho*Wo)]
|
||||
const auto out_gemm_m_gemm_k_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_pass_through_transform(N_ * Ho_ * Wo_)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
//[Gm, K, Y*X*C] -> [Gm*K, Y*X*C]
|
||||
const auto wei_gemm_m_gemm_n_grid_desc = transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_pass_through_transform(Y_ * X_ * C_)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Hi, Wi, C, Gm] -> [N, Hip, Wip, C, Gm]
|
||||
const auto in_n_hip_wip_c_gm_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_),
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Hip, Wip, C, Gm] -> [N, (Y, Wo), (X, Wo), C, Gm]
|
||||
const auto in_n_y_ho_x_wo_c_gm_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_gm_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(
|
||||
make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(
|
||||
make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_),
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}, sequence<6>{}));
|
||||
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4 5 6] -> [0, 1]
|
||||
// [N, Y, Ho, X, Wo, C, Gm] -> [(Gm*Y*X*C), (N*Ho*Wo)]
|
||||
const auto in_gemm_n_gemm_k_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_gm_grid_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(Y_, X_, C_, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemm_m_gemm_k_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_gemm_m_gemm_n_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_pass_through_transform(Cmerged_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(Cmerged_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, Cmerged_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -969,6 +875,8 @@ struct TransformConvBwdWeightToGemm
|
||||
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
IndexType ZYX_;
|
||||
IndexType Kmerged_;
|
||||
IndexType Cmerged_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user