From f0d724135c441f8b1310228bc4dc3b9f9a19a9ac Mon Sep 17 00:00:00 2001 From: jakpiase Date: Sun, 1 Mar 2026 14:26:17 +0100 Subject: [PATCH] [CK] CK Tile improvements and fixes for depthwise merged convolutions forward (#4873) ## Motivation Performance benchmarks showed that old CK's depthwise merged convolutions are much faster than CK Tile's ones. ## Technical Details After investigation it showed up that the requirement that A/CVectorload is a multiple of gemm's rightmost dimension is too strict in case of processing multiple groups, because if tensor is in NHWGC/NHWGK format, then if C/K is equal to 1, we can use vectorloads on the G dimension, which is added by this PR. Filter5x5 specialization was also added, because some models are using it, it's similar to 3x3, the only difference is the window size. This addition was needed, because of the differences of tensor descriptor transformations betweeen CK and CK Tile. In old CK the case of grouped depthwise 5x5 convs was supported via Default specialization, but in CK Tile that case was not working properly. ## Test Plan Performance was tested by our internal test suite, which contains several DL models. ## Test Result Tests results showed significant performance uplift for depthwise(3x3, 5x5) cases --------- Co-authored-by: Bartlomiej Kocot --- .../grouped_convolution_forward_kernel.hpp | 66 ++++++++++-- .../utils/transform_conv_fwd_to_gemm.hpp | 101 ++++++------------ 2 files changed, 95 insertions(+), 72 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 572f52b40d..bbbd248787 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -509,6 +509,7 @@ struct GroupedConvolutionForwardKernel static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); static constexpr auto I3 = number<3>(); + static constexpr auto I5 = number<5>(); static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, "Not supported!"); @@ -744,8 +745,9 @@ struct GroupedConvolutionForwardKernel if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) { - // Check access per C - if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0) + // Check access for A tensor + if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0 && + GroupedConvTraitsType_::NumGroupsToMerge == 1) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { @@ -753,6 +755,28 @@ struct GroupedConvolutionForwardKernel } return false; } + else if(GroupedConvTraitsType_::NumGroupsToMerge > 1) + { + if(ConvC != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow " + "vector reads on group dimension!"); + } + return false; + } + + const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}]; + if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!"); + } + return false; + } + } } else { @@ -794,12 +818,30 @@ struct GroupedConvolutionForwardKernel { if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + // Try to read over G + if(GroupedConvTraitsType_::NumGroupsToMerge > 1) { - CK_TILE_ERROR( - "Conv K is not a multiple of vector store size for output image!"); + const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}]; + if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0 || + ConvG % GroupedConvTraitsType_::VectorSizeC != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge to allow " + "writing over G dimension"); + } + return false; + } + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "ConvK is not a multiple of vector store size for output image!"); + } + return false; } - return false; } } else @@ -813,6 +855,18 @@ struct GroupedConvolutionForwardKernel if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1) { + // currently group merging works only for C == 1 due to tensor transformation + // limitations + if(ConvC != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow " + "vector reads on group dimension!"); + } + return false; + } + const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}]; if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0) { diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index 8bea7f653c..54fec53d56 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -660,8 +660,8 @@ struct TransformConvFwdToGemm else { const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N_, Wi_, NumGroupsToMerge, C_), - make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_), + make_tuple(N_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_), number{}, I1); @@ -669,26 +669,24 @@ struct TransformConvFwdToGemm in_n_wi_c_desc, make_tuple(make_pass_through_transform(N_), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); const auto in_n_x_wo_c_desc = transform_tensor_descriptor( in_n_wip_c_desc, make_tuple(make_pass_through_transform(N_), make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{})); + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); return transform_tensor_descriptor( in_n_x_wo_c_desc, make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), - make_merge_transform(make_tuple(X_, C_))), - make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_merge_transform(make_tuple(X_))), + make_tuple(sequence<0, 2, 3>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1>{})); } } @@ -906,11 +904,10 @@ struct TransformConvFwdToGemm } else { - + // IsSupported ensures C == 1 to allow reading on G dimension const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( - make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), - make_tuple( - NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_), + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_), number{}, I1); @@ -919,12 +916,9 @@ struct TransformConvFwdToGemm make_tuple(make_pass_through_transform(N_), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( in_n_hip_wip_groups_c_desc, @@ -933,21 +927,15 @@ struct TransformConvFwdToGemm make_tuple(ConvDilationH_, ConvStrideH_)), make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5>{}, - sequence<6>{})); + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); return transform_tensor_descriptor( in_n_y_ho_x_wo_groups_c_desc, make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), - make_merge_transform(make_tuple(Y_, X_, C_))), - make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3, 6>{}), + make_merge_transform(make_tuple(Y_, X_))), + make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); } } @@ -1214,14 +1202,10 @@ struct TransformConvFwdToGemm } else { + // IsSupported ensures C == 1 to allow reading on G dimension const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), - make_tuple(NStrideTensorA_, - DiStride_, - HiStride_, - WiStride_, - GStrideTensorA_, - CStrideTensorA_), + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_), number{}, I1); @@ -1231,20 +1215,11 @@ struct TransformConvFwdToGemm make_pad_transform(Di_, InLeftPadD_, InRightPadD_), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{})); + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( in_n_hip_wip_c_desc, @@ -1255,27 +1230,21 @@ struct TransformConvFwdToGemm make_tuple(ConvDilationH_, ConvStrideH_)), make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(NumGroupsToMerge), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5, 6>{}, - sequence<7>{}, - sequence<8>{})); + sequence<7>{})); return transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_desc, make_tuple( make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), - make_merge_transform(make_tuple(Z_, Y_, X_, C_))), - make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5, 8>{}), + make_merge_transform(make_tuple(Z_, Y_, X_))), + make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}), make_tuple(sequence<0>{}, sequence<1>{})); } }