From f1c40b8fbaee16a89a8564311d5bf4c8083abfd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Tue, 2 Dec 2025 15:23:32 +0200 Subject: [PATCH] [CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136) * Merge fwd conv groups in CK Tile. * Fix building CK fwd convs. * Add number of merged groups to conv fwd kernel name. * Get number of merged groups from conv config. * Rename GemmConfig to ConvConfig. * Clean-up TODOs. * Check that number of conv groups must be divisible by the number of merged groups. * Improve error handling in the conv fwd example. * Fix clang-format. * Fix group offsets. * Fix merge problem. * Address feedback from code review. * Fix clang-formatting. [ROCm/composable_kernel commit: 66832861ad78cc63584c32e5d231fd29a99c57b3] --- .../grouped_convolution_forward.cpp | 12 +- ...ped_convolution_backward_weight_kernel.hpp | 2 - .../grouped_convolution_forward_kernel.hpp | 127 ++++++++++++------ .../utils/transform_conv_fwd_to_gemm.hpp | 28 ++-- 4 files changed, 111 insertions(+), 58 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index af9820df2d..d26aaa98e3 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -50,9 +50,17 @@ int run_grouped_conv_fwd_example(int argc, char* argv[]) int main(int argc, char* argv[]) { + try + { #if CK_TILE_USE_WMMA - return !run_grouped_conv_fwd_example(argc, argv); + return !run_grouped_conv_fwd_example(argc, argv); #else - return !run_grouped_conv_fwd_example(argc, argv); + return !run_grouped_conv_fwd_example(argc, argv); #endif + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } } diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 5b7ce04638..6ef1d84a6e 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -643,8 +643,6 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!"); return false; } - - // TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock? } return true; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 94f69e5d91..72ba17c5a5 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -28,7 +28,6 @@ namespace ck_tile { template struct GroupedConvFwdKernelArgs { - using ConvToGemmFwdTransformer = TransformConvFwdToGemm(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsToMerge; + group_stride_b = args.K_ * args.C_ * NumGroupsToMerge * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, std::multiplies()); - group_stride_c = args.K_; + group_stride_c = args.K_ * NumGroupsToMerge; // Initialize Split-N support fields for 1D convolution (NWGC layout) // Get the actual split N from transformer @@ -121,8 +120,20 @@ struct GroupedConvFwdKernelArgs input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0]; output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0]; + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl; + } } template < @@ -163,11 +174,6 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - // Note: GemmM will be set after Split-N calculation - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; - GemmBatch = args.G_; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -192,13 +198,14 @@ struct GroupedConvFwdKernelArgs c_grid_desc_m_n = transformer_.template MakeCDescriptor_M_N(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsToMerge; + group_stride_b = args.K_ * args.C_ * NumGroupsToMerge * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, std::multiplies()); - group_stride_c = args.K_; + group_stride_c = args.K_ * NumGroupsToMerge; // Initialize Split-N support fields for 2D convolution (NHWGC layout) // Get the actual split N from transformer @@ -213,8 +220,20 @@ struct GroupedConvFwdKernelArgs output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl; + } } template < @@ -262,12 +281,6 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - // Note: GemmM will be set after Split-N calculation - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * - args.filter_spatial_lengths_[2]; - GemmBatch = args.G_; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -292,13 +305,14 @@ struct GroupedConvFwdKernelArgs c_grid_desc_m_n = transformer_.template MakeCDescriptor_M_N(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsToMerge; + group_stride_b = args.K_ * args.C_ * NumGroupsToMerge * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, std::multiplies()); - group_stride_c = args.K_; + group_stride_c = args.K_ * NumGroupsToMerge; // Initialize Split-N support fields for 3D convolution (NDHWGC layout) // Get the actual split N from transformer @@ -313,11 +327,21 @@ struct GroupedConvFwdKernelArgs output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * - args.output_spatial_lengths_[2]; - } + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl; + } + } using AGridDescMK = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K())>; @@ -343,6 +367,7 @@ struct GroupedConvFwdKernelArgs index_t GemmN; index_t GemmK; index_t GemmBatch; + index_t NumGroupsToMerge; const void* in_ptr; const void* wei_ptr; @@ -567,13 +592,25 @@ struct GroupedConvolutionForwardKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); + if (NumGroupsToMerge > 1) { + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + "merge", + NumGroupsToMerge); + } else { + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName()); + } // clang-format on } @@ -742,6 +779,16 @@ struct GroupedConvolutionForwardKernel return false; } + if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1) + { + const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}]; + if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0) + { + CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!"); + return false; + } + } + return true; } diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index ae67b30e70..8bea7f653c 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -470,10 +470,10 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -701,11 +701,11 @@ struct TransformConvFwdToGemm CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_; IndexType HiStride_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -960,12 +960,12 @@ struct TransformConvFwdToGemm CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType DiStride_ = Hi_ * Wi_ * G_ * C_; IndexType HiStride_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -1289,9 +1289,9 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeBDescriptor_N_K() const { - IndexType CStrideTensorB_ = 1; - IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_; + IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; + IndexType CStrideTensorB_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) { @@ -1356,10 +1356,10 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Wo_; if constexpr(NumGroupsToMerge == 1) @@ -1417,11 +1417,11 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_; IndexType HoStride_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1) @@ -1482,12 +1482,12 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType DoStride_ = Ho_ * Wo_ * G_ * K_; IndexType HoStride_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1)