From 23ce6028aba8999be9450da58e416a9bf9dfbd28 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Tue, 2 Dec 2025 11:37:26 +0100 Subject: [PATCH] [CK_TILE] Add indexing optimizations for conv bwd data (#3309) * add indexing optimizations for conv bwd data * fix formating [ROCm/composable_kernel commit: 59265d5eb2030c188bde9b6425e6a3bb8fc6f58f] --- .../utils/transform_conv_bwd_data_to_gemm.hpp | 913 ++++++++++-------- 1 file changed, 504 insertions(+), 409 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp index 71c3dc4cdf..deb4dcb3db 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp @@ -8,7 +8,7 @@ namespace ck_tile { template 1 - return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), - make_tuple(NStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Wo_, K_), make_tuple(WoStride, KStride), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), + make_tuple(NStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> CK_TILE_HOST auto make_wei_grid_desc() const { // GKXC - return make_naive_tensor_descriptor( - make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number{}, I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(K_, C_), make_tuple(C_, I1), number{}, I1); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number{}, I1); + } } template ::type = false> @@ -491,14 +507,22 @@ struct TransformConvBwdDataToGemm { // NWGC const index_t NStride = Wi_ * G_ * C_; - const index_t WiStride = G_ * C_; // GC? + const index_t WiStride = G_ * C_; constexpr auto CStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), - make_tuple(NStride, WiStride, CStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Wi_, C_), make_tuple(WiStride, CStride), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride), + number{}, + I1); + } } template ::type = false> @@ -512,10 +536,20 @@ struct TransformConvBwdDataToGemm // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), - make_tuple(NStride, HoStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(WoStride, KStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStride, HoStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> @@ -528,20 +562,38 @@ struct TransformConvBwdDataToGemm constexpr auto CStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), - make_tuple(NStride, HiStride, WiStride, CStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Hi_ * Wi_, C_), + make_tuple(WiStride, CStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride), + number{}, + I1); + } } template ::type = false> CK_TILE_HOST auto make_wei_grid_desc() const { // GKYXC - return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_), - make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(K_, C_), make_tuple(C_, I1), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_), + make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1), + number{}, + I1); + } } template ::type = false> @@ -555,11 +607,21 @@ struct TransformConvBwdDataToGemm constexpr auto KStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor( - make_tuple(N_, Do_, Ho_, Wo_, K_), - make_tuple(NStride, DoStride, HoStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(WoStride, KStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple(NStride, DoStride, HoStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> @@ -612,103 +674,111 @@ struct TransformConvBwdDataToGemm const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_wop_k_grid_desc = - transform_tensor_descriptor(out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( - out_n_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), - make_merge_transform(make_tuple(N_, WTildeSlice))), - make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_merge_transform(make_tuple(N_, WTildeSlice))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); - // B: weight tensor comes in K_N - const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + // B: weight tensor comes in K_N + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( - wei_k_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); + const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); - const auto wei_gemmn_gemmkraw_grid_desc = - transform_tensor_descriptor(wei_k_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // c: input - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + // c: input + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } template ::type = false> @@ -734,39 +804,135 @@ struct TransformConvBwdDataToGemm const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); const auto out_grid_desc = make_out_grid_desc(); - const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<2>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<>{}, + sequence<>{}, + sequence<3>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, @@ -774,111 +940,23 @@ struct TransformConvBwdDataToGemm sequence<4>{}, sequence<5>{}), make_tuple(sequence<0>{}, + sequence<>{}, sequence<1>{}, + sequence<>{}, sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{})); + sequence<3>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), - make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // B: weight tensor comes in K_N - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto wei_k_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<3>{}, - sequence<2>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<>{}, - sequence<>{}, - sequence<3>{})); - - const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( - wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - // c: input - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<>{}, - sequence<1>{}, - sequence<>{}, - sequence<2>{}, - sequence<3>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } template ::type = false> @@ -914,45 +992,174 @@ struct TransformConvBwdDataToGemm const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Do_, I0, I0), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(ZDot_, DTilde_), - make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(ZDot_, I0, ZDotSlice), - make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<5>{}, + sequence<2>{}, + sequence<4>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<>{}, + sequence<>{}, + sequence<>{}, + sequence<4>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, @@ -962,138 +1169,26 @@ struct TransformConvBwdDataToGemm sequence<6>{}, sequence<7>{}), make_tuple(sequence<0>{}, + sequence<>{}, sequence<1>{}, + sequence<>{}, sequence<2>{}, + sequence<>{}, sequence<3>{}, - sequence<4>{}, - sequence<5>{}, - sequence<6>{}, - sequence<7>{})); + sequence<4>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), - make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // B: weight tensor comes in K_N - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(ZDot_, ZTilde_), - make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto wei_k_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(ZDot_, I0, ZDotSlice), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxZTilde_), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<3>{}, - sequence<5>{}, - sequence<2>{}, - sequence<4>{}, - sequence<6>{}, - sequence<7>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<>{}, - sequence<>{}, - sequence<>{}, - sequence<4>{})); - - const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( - wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - // c: input - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Di_, InLeftPadD_, InRightPadD_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(ZTilde_, DTilde_), - make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxZTilde_), - make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}, - sequence<6>{}, - sequence<7>{}), - make_tuple(sequence<0>{}, - sequence<>{}, - sequence<1>{}, - sequence<>{}, - sequence<2>{}, - sequence<>{}, - sequence<3>{}, - sequence<4>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } IndexType G_, N_, original_N_;