diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 4b91382d10..c114b90ee2 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1361,11 +1361,11 @@ struct TransformConvFwdToGemm { const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( - make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}), - make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + make_tuple(NumGroupsToMerge, K_, FilterSizeNumType{}), + make_tuple(GStrideTensorB_, KStrideTensorB_, CStrideTensorB_)); return transform_tensor_descriptor( wei_gemmn_groups_gemmk_desc, - make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)), make_pass_through_transform(FilterSizeNumType{})), make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -1495,16 +1495,16 @@ struct TransformConvFwdToGemm else { const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor( - make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(N_, Wo_, NumGroupsToMerge, 1, K_), make_tuple( - NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); + NStrideTensorC_, WoStride_, GStrideTensorC_, GStrideTensorC_, KStrideTensorC_)); // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( nhwo_groups_k_1_desc, make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(K_), - make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(K_)), make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // We need only matrices from diagonal. X_or returns 0 for the same @@ -1518,13 +1518,13 @@ struct TransformConvFwdToGemm make_tuple(make_pass_through_transform(NDoHoWo), make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), - make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -1608,21 +1608,21 @@ struct TransformConvFwdToGemm else { const auto nhwo_groups_k_1_desc = - make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, 1, K_), make_tuple(NStrideTensorC_, DoStride_, HoStride_, WoStride_, GStrideTensorC_, - KStrideTensorC_, - GStrideTensorC_)); + GStrideTensorC_, + KStrideTensorC_)); // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( nhwo_groups_k_1_desc, make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(K_), - make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(K_)), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // We need only matrices from diagonal. X_or returns 0 for the same @@ -1636,13 +1636,13 @@ struct TransformConvFwdToGemm make_tuple(make_pass_through_transform(NDoHoWo), make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), - make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, K_))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); }