From 558054eadb461dab0a19ca99212f93268cd2e730 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Fri, 26 Sep 2025 13:38:24 +0000 Subject: [PATCH] WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases. --- .../transform_conv_bwd_weight_to_gemm.hpp | 172 ++++-------------- 1 file changed, 40 insertions(+), 132 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp index 69d9c00161..22aaf5d360 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -217,7 +217,9 @@ struct TransformConvBwdWeightToGemm InRightPadD_{I0}, InRightPadH_{input_right_pads[I0]}, InRightPadW_{input_right_pads[I1]}, - ZYX_{Y_ * X_} + ZYX_{Y_ * X_}, + Kmerged_{K_}, + Cmerged_{C_} { static_assert(std::is_same_v> || std::is_same_v>); @@ -235,6 +237,13 @@ struct TransformConvBwdWeightToGemm } #endif N_ = c_g_n_k_wos_lengths[I1]; + + // Group merging + if constexpr (NumGroupsToMerge > 1) + { + Cmerged_ = integer_divide_ceil(C_, NumGroupsToMerge) * NumGroupsToMerge; + Kmerged_ = integer_divide_ceil(K_, NumGroupsToMerge) * NumGroupsToMerge; + } } template 1) - { - const index_t KStride = G_; - constexpr auto GStride = I1; - return make_naive_tensor_descriptor( - make_tuple(NumGroupsToMerge, K_, N_ * Ho_ * Wo_), - make_tuple(GStride, KStride, NDoHoWoStride)); - } - else - { - constexpr auto KStride = I1; - return make_naive_tensor_descriptor( - make_tuple(K_, N_ * Ho_ * Wo_), + constexpr auto KStride = I1; + + return make_naive_tensor_descriptor( + make_tuple(Kmerged_, N_ * Ho_ * Wo_), make_tuple(KStride, NDoHoWoStride)); - } } template ::type = false> @@ -515,22 +513,11 @@ struct TransformConvBwdWeightToGemm const index_t NStride = Hi_ * Wi_ * G_ * C_; const index_t HiStride = Wi_ * G_ * C_; const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; - if constexpr (NumGroupsToMerge > 1) - { - const index_t CStride = G_; - constexpr auto GStride = I1; - return make_naive_tensor_descriptor( - make_tuple(N_, Hi_, Wi_, C_, NumGroupsToMerge), - make_tuple(NStride, HiStride, WiStride, CStride, GStride)); - } - else - { - constexpr auto CStride = I1; - return make_naive_tensor_descriptor( - make_tuple(N_, Hi_, Wi_, C_), + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, Cmerged_), make_tuple(NStride, HiStride, WiStride, CStride)); - } } template ::type = false> @@ -539,20 +526,10 @@ struct TransformConvBwdWeightToGemm // GKYXC const index_t KStride = Y_ * X_ * C_; constexpr auto CStride = I1; - - if constexpr (NumGroupsToMerge > 1) - { - const index_t GStride = K_ * Y_ * X_ * C_; - return make_naive_tensor_descriptor( - make_tuple(NumGroupsToMerge, K_, Y_ * X_ * C_), - make_tuple(GStride, KStride, CStride)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(K_, Y_ * X_ * C_), + + return make_naive_tensor_descriptor( + make_tuple(Kmerged_, Y_ * X_ * C_), make_tuple(KStride, CStride)); - } } ////////////////// @@ -742,104 +719,33 @@ struct TransformConvBwdWeightToGemm const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); - // B: input tensor comes in K_N - if constexpr (NumGroupsToMerge > 1) - { - // Output tensor transformation - // [0, 1, 2] -> [0, 1] - // [Gm, K, (N*Ho*Wo)] -> [(K*Gm), (N*Ho*Wo)] - const auto out_gemm_m_gemm_k_grid_desc = - transform_tensor_descriptor( - out_grid_desc, - make_tuple( - make_merge_transform(make_tuple(NumGroupsToMerge, K_)), - make_pass_through_transform(N_ * Ho_ * Wo_)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - //[Gm, K, Y*X*C] -> [Gm*K, Y*X*C] - const auto wei_gemm_m_gemm_n_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple( - make_merge_transform(make_tuple(NumGroupsToMerge, K_)), - make_pass_through_transform(Y_ * X_ * C_)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - // Input tensor transformation, part 1. - // [N, Hi, Wi, C, Gm] -> [N, Hip, Wip, C, Gm] - const auto in_n_hip_wip_c_gm_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_), - make_pass_through_transform(NumGroupsToMerge)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - - // Input tensor transformation, part 2. - // [N, Hip, Wip, C, Gm] -> [N, (Y, Wo), (X, Wo), C, Gm] - const auto in_n_y_ho_x_wo_c_gm_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_gm_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform( - make_tuple(Y_, Ho_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform( - make_tuple(X_, Wo_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_), - make_pass_through_transform(NumGroupsToMerge)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}, sequence<6>{})); - - // Input tensor transformation, part 3. - // [0, 1, 2, 3, 4 5 6] -> [0, 1] - // [N, Y, Ho, X, Wo, C, Gm] -> [(Gm*Y*X*C), (N*Ho*Wo)] - const auto in_gemm_n_gemm_k_grid_desc = - transform_tensor_descriptor( - in_n_y_ho_x_wo_c_gm_grid_desc, - make_tuple( - make_merge_transform(make_tuple(Y_, X_, C_, NumGroupsToMerge)), - make_merge_transform(make_tuple(N_, Ho_, Wo_))), - make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tuple(out_gemm_m_gemm_k_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_gemm_m_gemm_n_grid_desc); - } - else - { - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_grid_desc, make_tuple(make_pass_through_transform(N_), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), + make_pass_through_transform(Cmerged_)), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(Cmerged_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - const auto in_gemmn_gemmktotal_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)), - make_merge_transform(make_tuple(N_, Ho_, Wo_))), - make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto in_gemmn_gemmktotal_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y_, X_, Cmerged_)), + make_merge_transform(make_tuple(N_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); - } + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); } template ::type = false> @@ -969,6 +875,8 @@ struct TransformConvBwdWeightToGemm IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; IndexType InRightPadD_, InRightPadH_, InRightPadW_; IndexType ZYX_; + IndexType Kmerged_; + IndexType Cmerged_; }; } // namespace ck_tile