diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp index 0b4744a3a1..0b290a474c 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -8,7 +8,7 @@ namespace ck_tile { template 1) + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) { - const auto BatchStride = C_; - return make_naive_tensor_descriptor(make_tuple(N_, Wi_, NumGroupsToMerge, C_), - make_tuple(NStride, WiStride, BatchStride, CStride), - number{}, - I1); + if constexpr(NumGroupsToMerge > 1) + { + const index_t BatchStride = C_; + return make_naive_tensor_descriptor(make_tuple(N_ * Wi_, NumGroupsToMerge, C_), + make_tuple(WiStride, BatchStride, CStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Wi_, C_), + make_tuple(WiStride, CStride), + number{}, + I1); + } } else { + if constexpr(NumGroupsToMerge > 1) + { + const index_t BatchStride = C_; + return make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStride, WiStride, BatchStride, CStride), + number{}, + I1); + } + else + { - return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), - make_tuple(NStride, WiStride, CStride), - number{}, - I1); + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride), + number{}, + I1); + } } } @@ -551,21 +573,44 @@ struct TransformConvBwdWeightToGemm const index_t WiStride = G_ * C_; constexpr auto CStride = I1; - if constexpr(NumGroupsToMerge > 1) + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) { - const auto BatchStride = C_; - return make_naive_tensor_descriptor( - make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N - make_tuple(NStride, HiStride, WiStride, 
BatchStride, CStride), - number{}, - I1); + if constexpr(NumGroupsToMerge > 1) + { + const index_t BatchStride = C_; + return make_naive_tensor_descriptor( + make_tuple(N_ * Hi_ * Wi_, NumGroupsToMerge, C_), // K_Gm_N + make_tuple(WiStride, BatchStride, CStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Hi_ * Wi_, C_), // K_N + make_tuple(WiStride, CStride), + number{}, + I1); + } } else { - return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N - make_tuple(NStride, HiStride, WiStride, CStride), - number{}, - I1); + if constexpr(NumGroupsToMerge > 1) + { + const index_t BatchStride = C_; + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N + make_tuple(NStride, HiStride, WiStride, BatchStride, CStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), // K_N + make_tuple(NStride, HiStride, WiStride, CStride), + number{}, + I1); + } } } @@ -664,22 +709,44 @@ struct TransformConvBwdWeightToGemm const index_t WiStride = G_ * C_; constexpr auto CStride = I1; - if constexpr(NumGroupsToMerge > 1) + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) { - const index_t BatchStride = C_; - return make_naive_tensor_descriptor( - make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), - make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride), - number{}, - I1); + if constexpr(NumGroupsToMerge > 1) + { + const index_t BatchStride = C_; + return make_naive_tensor_descriptor( + make_tuple(N_ * Di_ * Hi_ * Wi_, NumGroupsToMerge, C_), + make_tuple(WiStride, BatchStride, CStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Di_ * Hi_ * Wi_, C_), + make_tuple(WiStride, CStride), + number{}, + I1); + } } else { - return make_naive_tensor_descriptor( - make_tuple(N_, Di_, Hi_, Wi_, C_), - make_tuple(NStride, DiStride, HiStride, WiStride, 
CStride), - number{}, - I1); + if constexpr(NumGroupsToMerge > 1) + { + const index_t BatchStride = C_; + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride), + number{}, + I1); + } } } @@ -755,83 +822,111 @@ struct TransformConvBwdWeightToGemm const auto wei_grid_desc = make_wei_grid_desc(); // B: input tensor comes in K_N - if constexpr(NumGroupsToMerge > 1) + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) { - // Output tensor transformation - // [0, 1, 2] -> [0, 1] - // [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)] - const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_ * Wo_), - make_merge_transform(make_tuple(NumGroupsToMerge, K_))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + if constexpr(NumGroupsToMerge > 1) + { + const auto out_grid_merged_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // Input tensor transformation, part 1. 
- // [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C] - const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + const auto in_grid_merged_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, C_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // Input tensor transformation, part 2. - // [N, Wip, Gm, C] -> [N, X, Wo, Gm, C] - const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor( - in_n_wip_gm_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(X_, Wo_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{})); - - // Input tensor transformation, part 3. 
- // [0, 1, 2, 3, 4] -> [0, 1] - // [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)] - const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor( - in_n_x_wo_gm_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)), - make_merge_transform(make_tuple(N_, Wo_))), - make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return make_tuple( - out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc); + return make_tuple(out_grid_merged_desc, in_grid_merged_desc, wei_grid_desc); + } + else + { + return make_tuple(out_grid_desc, in_grid_desc, wei_grid_desc); + } } else { - // [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C] - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + if constexpr(NumGroupsToMerge > 1) + { + // Output tensor transformation + // [0, 1, 2] -> [0, 1] + // [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)] + const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // [N, Wip, C] -> [N, X, Wo, C] - const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(X_, Wo_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + // Input tensor transformation, part 1. 
+ // [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C] + const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - const auto in_gemmn_gemmktotal_grid_desc = - transform_tensor_descriptor(in_n_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(X_, C_)), - make_merge_transform(make_tuple(N_, Wo_))), - make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + // Input tensor transformation, part 2. + // [N, Wip, Gm, C] -> [N, X, Wo, Gm, C] + const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor( + in_n_wip_gm_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{})); - return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + // Input tensor transformation, part 3. 
+ // [0, 1, 2, 3, 4] -> [0, 1] + // [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)] + const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor( + in_n_x_wo_gm_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)), + make_merge_transform(make_tuple(N_, Wo_))), + make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return make_tuple( + out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc); + } + else + { + // [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C] + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + // [N, Wip, C] -> [N, X, Wo, C] + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor( + in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X_, C_)), + make_merge_transform(make_tuple(N_, Wo_))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } } } @@ -843,94 +938,122 @@ struct TransformConvBwdWeightToGemm const auto wei_grid_desc = make_wei_grid_desc(); // B: input tensor comes in K_N - if constexpr(NumGroupsToMerge > 1) + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) { - // Output tensor 
transformation - // [0, 1, 2] -> [0, 1] - // [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (K*Gm)] - const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_), - make_merge_transform(make_tuple(NumGroupsToMerge, K_))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + if constexpr(NumGroupsToMerge > 1) + { + const auto out_grid_merged_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // Input tensor transformation, part 1. - // [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C] - const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + const auto in_grid_merged_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, C_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // Input tensor transformation, part 2. 
- // [N, Hip, Wip, Gm, C] -> [N, (Y, Wo), (X, Wo), Gm, C] - const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_gm_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Y_, Ho_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(X_, Wo_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5>{}, - sequence<6>{})); - - // Input tensor transformation, part 3. - // [0, 1, 2, 3, 4 5 6] -> [0, 1] - // [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Gm*Y*X*C)] - const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_gm_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)), - make_merge_transform(make_tuple(N_, Ho_, Wo_))), - make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return make_tuple( - out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc); + return make_tuple(out_grid_merged_desc, in_grid_merged_desc, wei_grid_desc); + } + else + { + return make_tuple(out_grid_desc, in_grid_desc, wei_grid_desc); + } } else { - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + if constexpr(NumGroupsToMerge > 1) + { + // Output tensor transformation + // [0, 1, 2] -> [0, 1] + // [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (K*Gm)] + const auto 
out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Y_, Ho_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(X_, Wo_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + // Input tensor transformation, part 1. + // [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C] + const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)), - make_merge_transform(make_tuple(N_, Ho_, Wo_))), - make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + // Input tensor transformation, part 2. 
+ // [N, Hip, Wip, Gm, C] -> [N, (Y, Ho), (X, Wo), Gm, C] + const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_gm_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5>{}, + sequence<6>{})); + + // Input tensor transformation, part 3. + // [0, 1, 2, 3, 4, 5, 6] -> [0, 1] + // [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Y*X*Gm*C)] + const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_gm_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)), + make_merge_transform(make_tuple(N_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return make_tuple( + out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc); + } + else + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + 
make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)), + make_merge_transform(make_tuple(N_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } } } @@ -942,120 +1065,148 @@ struct TransformConvBwdWeightToGemm const auto wei_grid_desc = make_wei_grid_desc(); // B: input tensor comes in K_N - if constexpr(NumGroupsToMerge > 1) + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) { - // Output tensor transformation - // [0, 1, 2] -> [0, 1] - // [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (K*Gm)] - const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), - make_merge_transform(make_tuple(NumGroupsToMerge, K_))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + if constexpr(NumGroupsToMerge > 1) + { + const auto out_grid_merged_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // Input tensor transformation, part 1. 
- // [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C] - const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Di_, InLeftPadD_, InRightPadD_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{})); + const auto in_grid_merged_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, C_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // Input tensor transformation, part 2. - // [N, Zip, Hip, Wip, Gm, C] -> [N, (Z, Zo), (Y, Wo), (X, Wo), Gm, C] - const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor( - in_n_dip_hip_wip_gm_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Z_, Do_), - make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(Y_, Ho_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(X_, Wo_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{}, - sequence<8>{})); - - // Input tensor transformation, part 3. 
- // [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1] - // [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)] - const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_gm_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)), - make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))), - make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return make_tuple( - out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc); + return make_tuple(out_grid_merged_desc, in_grid_merged_desc, wei_grid_desc); + } + else + { + return make_tuple(out_grid_desc, in_grid_desc, wei_grid_desc); + } } else { - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Di_, InLeftPadD_, InRightPadD_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + if constexpr(NumGroupsToMerge > 1) + { + // Output tensor transformation + // [0, 1, 2] -> [0, 1] + // [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (K*Gm)] + const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(Z_, Do_), - make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(Y_, Ho_), - 
make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(X_, Wo_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); + // Input tensor transformation, part 1. + // [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C] + const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); - const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)), - make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))), - make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + // Input tensor transformation, part 2. 
+ // [N, Dip, Hip, Wip, Gm, C] -> [N, (Z, Do), (Y, Ho), (X, Wo), Gm, C] + const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_gm_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{}, + sequence<8>{})); - return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + // Input tensor transformation, part 3. + // [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1] + // [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)] + const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_gm_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return make_tuple( + out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc); + } + else + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, 
sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } } }