diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
index 572f52b40d..bbbd248787 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
@@ -509,6 +509,7 @@ struct GroupedConvolutionForwardKernel
     static constexpr auto I1 = number<1>();
     static constexpr auto I2 = number<2>();
     static constexpr auto I3 = number<3>();
+    static constexpr auto I5 = number<5>();
 
     static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
                   "Not supported!");
@@ -744,8 +745,9 @@ struct GroupedConvolutionForwardKernel
         if constexpr(std::is_same_v || std::is_same_v ||
                      std::is_same_v)
         {
-            // Check access per C
-            if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
+            // Check access for A tensor
+            if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0 &&
+               GroupedConvTraitsType_::NumGroupsToMerge == 1)
            {
                 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                 {
@@ -753,6 +755,28 @@ struct GroupedConvolutionForwardKernel
                 }
                 return false;
             }
+            else if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
+            {
+                if(ConvC != 1)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow "
+                                      "vector reads on the group dimension!");
+                    }
+                    return false;
+                }
+
+                const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
+                if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
+                    }
+                    return false;
+                }
+            }
         }
         else
         {
@@ -794,12 +818,30 @@ struct GroupedConvolutionForwardKernel
         {
             if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
             {
-                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                // Try to write over G instead
+                if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
                 {
-                    CK_TILE_ERROR(
-                        "Conv K is not a multiple of vector store size for output image!");
+                    const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
+                    if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0 ||
+                       ConvG % GroupedConvTraitsType_::VectorSizeC != 0)
+                    {
+                        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                        {
+                            CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge and of "
+                                          "VectorSizeC to allow writing over the G dimension!");
+                        }
+                        return false;
+                    }
+                }
+                else
+                {
+                    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                    {
+                        CK_TILE_ERROR(
+                            "ConvK is not a multiple of vector store size for output image!");
+                    }
+                    return false;
                 }
-                return false;
             }
         }
         else
@@ -813,6 +855,18 @@ struct GroupedConvolutionForwardKernel
 
         if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
         {
+            // Currently group merging works only for C == 1 due to tensor transformation
+            // limitations.
+            if(ConvC != 1)
+            {
+                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+                {
+                    CK_TILE_ERROR("ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow "
+                                  "vector reads on the group dimension!");
+                }
+                return false;
+            }
+
             const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
             if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
             {
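The A-tensor and C-tensor checks above reduce to the same memory-contiguity argument: merging groups is only legal when stepping through g lands on consecutive elements. The standalone sketch below illustrates this; it is not CK Tile code, and offset_nwgc plus all sizes are illustrative assumptions. In an NWGC layout the offset of element (n, wi, g, c) is ((n * Wi + wi) * G + g) * C + c, so the step between neighbouring groups equals C; only C == 1 gives the unit stride that vector reads over G (and the fallback vector writes over G, which additionally require ConvG to be divisible by both NumGroupsToMerge and VectorSizeC) rely on.

// Standalone illustration (assumed NWGC layout; not a CK Tile API).
#include <cstdint>
#include <iostream>

// Offset of input element (n, wi, g, c) for lengths Wi, G, C with unit C stride.
int64_t offset_nwgc(int64_t n, int64_t wi, int64_t g, int64_t c,
                    int64_t Wi, int64_t G, int64_t C)
{
    return ((n * Wi + wi) * G + g) * C + c;
}

int main()
{
    const int64_t Wi = 8, G = 8;
    for(int64_t C : {1, 4})
    {
        // Distance in elements between the same pixel of group g and group g + 1.
        const int64_t g_step =
            offset_nwgc(0, 0, 1, 0, Wi, G, C) - offset_nwgc(0, 0, 0, 0, Wi, G, C);
        std::cout << "C = " << C << ": step over g = " << g_step
                  << (g_step == 1 ? " (contiguous, vectorizable over G)\n"
                                  : " (strided, not vectorizable)\n");
    }
    return 0;
}

The sketch prints a unit step only for C == 1, which is exactly the case the new ConvC != 1 rejection above lets through.
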
diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
index 8bea7f653c..54fec53d56 100644
--- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
+++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
@@ -660,8 +660,8 @@ struct TransformConvFwdToGemm
         else
         {
             const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                make_tuple(N_, Wi_, NumGroupsToMerge, C_),
-                make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
+                make_tuple(N_, Wi_, NumGroupsToMerge),
+                make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
                 number{},
                 I1);
 
@@ -669,26 +669,24 @@ struct TransformConvFwdToGemm
                 in_n_wi_c_desc,
                 make_tuple(make_pass_through_transform(N_),
                            make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
-                           make_pass_through_transform(NumGroupsToMerge),
-                           make_pass_through_transform(C_)),
-                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
-                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
+                           make_pass_through_transform(NumGroupsToMerge)),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
 
             const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                 in_n_wip_c_desc,
                 make_tuple(make_pass_through_transform(N_),
                            make_embed_transform(make_tuple(X_, Wo_),
                                                 make_tuple(ConvDilationW_, ConvStrideW_)),
-                           make_pass_through_transform(NumGroupsToMerge),
-                           make_pass_through_transform(C_)),
-                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
-                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
+                           make_pass_through_transform(NumGroupsToMerge)),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
+                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
 
             return transform_tensor_descriptor(
                 in_n_x_wo_c_desc,
                 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
-                           make_merge_transform(make_tuple(X_, C_))),
-                make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
+                           make_merge_transform(make_tuple(X_))),
+                make_tuple(sequence<0, 2, 3>{}, sequence<1>{}),
                 make_tuple(sequence<0>{}, sequence<1>{}));
         }
     }
@@ -906,11 +904,10 @@ struct TransformConvFwdToGemm
         }
         else
         {
-
+            // IsSupported ensures C == 1 to allow vector reads along the G dimension
             const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
-                make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
-                make_tuple(
-                    NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
+                make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
+                make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
                 number{},
                 I1);
 
@@ -919,12 +916,9 @@ struct TransformConvFwdToGemm
                 make_tuple(make_pass_through_transform(N_),
                            make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                            make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
-                           make_pass_through_transform(NumGroupsToMerge),
-                           make_pass_through_transform(C_)),
-                make_tuple(
-                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
-                make_tuple(
-                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
+                           make_pass_through_transform(NumGroupsToMerge)),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
 
             const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
                 in_n_hip_wip_groups_c_desc,
@@ -933,21 +927,15 @@ struct TransformConvFwdToGemm
                                                 make_tuple(ConvDilationH_, ConvStrideH_)),
                            make_embed_transform(make_tuple(X_, Wo_),
                                                 make_tuple(ConvDilationW_, ConvStrideW_)),
-                           make_pass_through_transform(NumGroupsToMerge),
-                           make_pass_through_transform(C_)),
-                make_tuple(
-                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
-                make_tuple(sequence<0>{},
-                           sequence<1, 2>{},
-                           sequence<3, 4>{},
-                           sequence<5>{},
-                           sequence<6>{}));
+                           make_pass_through_transform(NumGroupsToMerge)),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
+                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
 
             return transform_tensor_descriptor(
                 in_n_y_ho_x_wo_groups_c_desc,
                 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
-                           make_merge_transform(make_tuple(Y_, X_, C_))),
-                make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3, 6>{}),
+                           make_merge_transform(make_tuple(Y_, X_))),
+                make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
                 make_tuple(sequence<0>{}, sequence<1>{}));
         }
     }
@@ -1214,14 +1202,10 @@ struct TransformConvFwdToGemm
         }
         else
         {
+            // IsSupported ensures C == 1 to allow vector reads along the G dimension
             const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
-                make_tuple(NStrideTensorA_,
-                           DiStride_,
-                           HiStride_,
-                           WiStride_,
-                           GStrideTensorA_,
-                           CStrideTensorA_),
+                make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
+                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
                 number{},
                 I1);
 
@@ -1231,20 +1215,11 @@ struct TransformConvFwdToGemm
                            make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
                            make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
                            make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
-                           make_pass_through_transform(NumGroupsToMerge),
-                           make_pass_through_transform(C_)),
-                make_tuple(sequence<0>{},
-                           sequence<1>{},
-                           sequence<2>{},
-                           sequence<3>{},
-                           sequence<4>{},
-                           sequence<5>{}),
-                make_tuple(sequence<0>{},
-                           sequence<1>{},
-                           sequence<2>{},
-                           sequence<3>{},
-                           sequence<4>{},
-                           sequence<5>{}));
+                           make_pass_through_transform(NumGroupsToMerge)),
+                make_tuple(
+                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
+                make_tuple(
+                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
 
             const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                 in_n_hip_wip_c_desc,
@@ -1255,27 +1230,21 @@ struct TransformConvFwdToGemm
                                                 make_tuple(ConvDilationH_, ConvStrideH_)),
                            make_embed_transform(make_tuple(X_, Wo_),
                                                 make_tuple(ConvDilationW_, ConvStrideW_)),
-                           make_pass_through_transform(NumGroupsToMerge),
-                           make_pass_through_transform(C_)),
-                make_tuple(sequence<0>{},
-                           sequence<1>{},
-                           sequence<2>{},
-                           sequence<3>{},
-                           sequence<4>{},
-                           sequence<5>{}),
+                           make_pass_through_transform(NumGroupsToMerge)),
+                make_tuple(
+                    sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
                 make_tuple(sequence<0>{},
                            sequence<1, 2>{},
                            sequence<3, 4>{},
                            sequence<5, 6>{},
-                           sequence<7>{},
-                           sequence<8>{}));
+                           sequence<7>{}));
 
             return transform_tensor_descriptor(
                 in_n_z_do_y_ho_x_wo_c_desc,
                 make_tuple(
                     make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
-                    make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
-                make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5, 8>{}),
+                    make_merge_transform(make_tuple(Z_, Y_, X_))),
+                make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
                 make_tuple(sequence<0>{}, sequence<1>{}));
         }
     }
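With C dropped from the descriptor chain (valid because IsSupported now guarantees C == 1), the 1D input descriptor exposes GEMM A as M x K with M = N * Wo * NumGroupsToMerge and K = X; the merge transform decomposes m into (n, wo, g) with g innermost, and the embed transform turns k into a window offset. The sketch below walks through that index arithmetic; all sizes and names are illustrative assumptions, not CK Tile APIs.

// Sketch of the merged 1D (NWGC, C == 1) descriptor arithmetic; sizes are assumed.
#include <cstdint>
#include <iostream>

int main()
{
    const int64_t N = 2, Wi = 16, G = 8, NumGroupsToMerge = 4;
    const int64_t X = 3, ConvStrideW = 1, ConvDilationW = 1, InLeftPadW = 1, InRightPadW = 1;
    const int64_t Wo =
        (Wi + InLeftPadW + InRightPadW - (X - 1) * ConvDilationW - 1) / ConvStrideW + 1;

    // NWGC strides with C == 1: g is the unit-stride dimension.
    const int64_t GStride = 1, WiStride = G, NStride = Wi * G;

    const int64_t M = N * Wo * NumGroupsToMerge, K = X;
    std::cout << "GEMM A: " << M << " x " << K << "\n";

    // Decompose one GEMM coordinate the way the merge/embed transforms do (g innermost).
    const int64_t m = 37, k = 2;
    const int64_t g  = m % NumGroupsToMerge;
    const int64_t wo = (m / NumGroupsToMerge) % Wo;
    const int64_t n  = m / (NumGroupsToMerge * Wo);
    // For some (m, k) this lands in the left/right padding handled by make_pad_transform.
    const int64_t wi = wo * ConvStrideW + k * ConvDilationW - InLeftPadW;
    std::cout << "A(" << m << ", " << k << ") -> n=" << n << ", wi=" << wi << ", g=" << g
              << ", offset=" << n * NStride + wi * WiStride + g * GStride << "\n";
    return 0;
}

The 2D and 3D branches above follow the same pattern, with (Y, X) and (Z, Y, X) forming K and the spatial outputs plus NumGroupsToMerge forming M.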