From abccb649d12e21392f104e22e1cf6f77e9c0d90e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:49:28 +0200 Subject: [PATCH] [CK_Tile] Merge multiple convolution groups into a single GEMM batch (#2986) * Fix compilation of the grouped conv examples. * Fix grouped conv bwd weight example output in CK Tile. * Add number of groups to merge to ck tile grouped gemm example. * Initial set of tests for TransformConvBwdWeightToGemm. * Added unit tests for TransformConvBwdWeightToGemm when conv groups are merged. * WIP: Tensor transformations. * Add unit tests for coordinate transforms. * Fully working conv group merging for TransformConvBwdWeightToGemm. * WIP: Merged conv groups offset calculation. * Added unit tests for tensor view. * WIP: Merged conv groups epilogue. * Enable running multiple conv groups per batch. * Add tests for tile_distribution_encoding. * Change example to match optimally depthwise convolution with merged groups. * Add more tests for tensor view. * Integration test for reading diagonal blocks from grouped distributed tensor. * Improved integration test. * Improve test for accessing diagonal blocks. * Added integration test for cshuffle epilogue LDS tile distribution. * Add more logging. * Increase the max number of reported errors. * WIP: merged conv groups GEMM epilogue changes. * LDS to global memory copy. * Fix tile window size for c block. * Integration test for CShuffle epilogue. * Improved CShuffle test. * WIP: Separate epilogue for merged conv groups. * Tile example parameters changes to match depthwise conv. * Offset fixes. * Epilogue fixes. * Working baseline for depthwise convolution with merged conv groups. * Fix build. * Initial unit tests for tensor descriptor. * Add one more unit test for tensor view. * WIP: LDS to global mem transfer using CK tile tensor descriptor and tile distribution encoding. 
* Fully functional LDS to global mem transfer using tensor descriptor and tile distribution encoding. * Add more comments, disable debug code. * Remove debug and other dead code. * Code clean-up for bwd tensor transformations. * Enable running multiple GEMM batches of merged conv groups. * Add compile check for assumed row-major layout. * Fix strides in 1D conv to gemm transformation. * WIP: Simplify conv to gemm transformations and handle K > 1 and C > 1 cases. * Fix case k > 1 and c=1. * Remove debug code. * Make MPerGroup and NPerGroup template parameters. * Add additional check for non-supported c > 1 case. * WIP: Put back the generic tensor descriptors for convolutions. * Fix tensor descriptors. * Remove the obsolete template parameters. * Add more instances. * Fix bugs in merged conv groups tensor descriptors. * Fix tensor descriptors for merged conv groups when K > 1. * Remove debug output. * Remove dead code. * Fix merge conflicts. * Code clean-up. * Remove unused code. * Run clang-formatting. * Remove debug prints and obsolete tests. * Check that number of convolution groups is multiple of merged groups. * Fix build after removing obsolete functionality. * Remove obsolete enumeration. * Fix new unit projects. * Remove unnecessary includes. * Fix passing the number of merged groups. * Remove unrelated tests. * Fix IsSupportedArgument for bwd weight conv kernel. * Fix clang formatting. * Fix the bwd weight conv to gemm mapping for num merged groups > 1. * GEMM config for conv group merging. * Fix clang-formatting. * Remove obsolete comment. * Fix typos in comment strings. * Increase the max number of reported errors when testing against reference implementation. * Rename gemm_config to conv_config. * Rename GemmConfig to ConvConfig and move NumGroupsToMerge into ConvConfig. * Change num_groups_to_merge to a boolean flag in the ck tile grouped conv example. * Run clang-format. * Add number of merged groups into kernel name string. 
* Remove group merging flag from CK Tile grouped conv example. [ROCm/composable_kernel commit: 121bf0e1f3325614560268fe9a4af6bbed38712a] --- .../{gemm_configs.hpp => conv_configs.hpp} | 51 +- .../grouped_convolution_backward_data.cpp | 4 +- .../grouped_convolution_backward_weight.cpp | 10 +- ...ed_convolution_backward_weight_invoker.hpp | 61 +- ..._convolution_backward_weight_two_stage.cpp | 12 +- ...tion_backward_weight_two_stage_invoker.hpp | 50 +- .../grouped_convolution_forward.cpp | 4 +- .../grouped_convolution_forward_invoker.hpp | 8 +- .../grouped_convolution_utils.hpp | 15 +- ...grouped_convolution_bwd_weight_example.inc | 16 +- .../algorithm/static_encoding_pattern.hpp | 5 +- include/ck_tile/host/check_err.hpp | 2 +- ...ped_convolution_backward_weight_kernel.hpp | 114 ++- .../grouped_convolution_forward_kernel.hpp | 1 + .../utils/grouped_convolution_utils.hpp | 11 +- .../transform_conv_bwd_weight_to_gemm.hpp | 658 ++++++++++++++---- .../utils/transform_conv_fwd_to_gemm.hpp | 2 +- 17 files changed, 755 insertions(+), 269 deletions(-) rename example/ck_tile/20_grouped_convolution/{gemm_configs.hpp => conv_configs.hpp} (85%) diff --git a/example/ck_tile/20_grouped_convolution/gemm_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp similarity index 85% rename from example/ck_tile/20_grouped_convolution/gemm_configs.hpp rename to example/ck_tile/20_grouped_convolution/conv_configs.hpp index 77e1c3af1a..1be6080383 100644 --- a/example/ck_tile/20_grouped_convolution/gemm_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -17,7 +17,7 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -struct GemmConfigBase +struct ConvConfigBase { static constexpr bool kPadM = true; static constexpr bool kPadN = true; @@ -29,6 +29,10 @@ struct GemmConfigBase static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; + static constexpr ck_tile::index_t VectorSizeA = 4; + 
static constexpr ck_tile::index_t VectorSizeB = 8; + static constexpr ck_tile::index_t VectorSizeC = 8; + static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -37,10 +41,12 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; + + static constexpr ck_tile::index_t NumGroupsToMerge = 1; }; template -struct GemmConfigMemoryInterwave : public GemmConfigBase +struct ConvConfigMemoryInterwave : public ConvConfigBase { // Memory friendly for Interwave scheduler static constexpr ck_tile::index_t M_Tile = 128; @@ -61,7 +67,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase }; template -struct GemmConfigMemoryIntrawave : public GemmConfigBase +struct ConvConfigMemoryIntrawave : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 32; @@ -80,7 +86,7 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase }; template -struct GemmConfigComputeV3 : public GemmConfigBase +struct ConvConfigComputeV3 : public ConvConfigBase { // Compute V3 only support Intrawave scheduler static constexpr ck_tile::index_t M_Tile = 16; @@ -100,7 +106,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase }; template -struct GemmConfigComputeV3_1 : public GemmConfigBase +struct ConvConfigComputeV3_1 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; @@ -119,7 +125,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase }; template -struct GemmConfigComputeV3_2 : public GemmConfigBase +struct ConvConfigComputeV3_2 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -140,7 +146,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase }; template -struct 
GemmConfigComputeV3_WMMA : public GemmConfigBase +struct ConvConfigComputeV3_WMMA : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -161,7 +167,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase }; template -struct GemmConfigComputeV4 : public GemmConfigBase +struct ConvConfigComputeV4 : public ConvConfigBase { // Compute V4 only support Intrawave scheduler // Using the ping pong reader in the lds level @@ -182,7 +188,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase }; template -struct GemmConfigComputeV4_1 : public GemmConfigBase +struct ConvConfigComputeV4_1 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; @@ -201,7 +207,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase }; template -struct GemmConfigComputeV5 : public GemmConfigBase +struct ConvConfigComputeV5 : public ConvConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -220,6 +226,31 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; +template +struct ConvConfigComputeV3_merged_groups : public ConvConfigBase +{ + static constexpr ck_tile::index_t VectorSizeA = 4; + static constexpr ck_tile::index_t VectorSizeB = 8; + static constexpr ck_tile::index_t VectorSizeC = 8; + + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = 
CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr ck_tile::index_t NumGroupsToMerge = 2; +}; + template struct ConvTypeConfig; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp index 6f3bedc32a..ad593b1418 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp @@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[]) int main(int argc, char* argv[]) { #if CK_TILE_USE_WMMA - return !run_grouped_conv_bwd_data_example(argc, argv); + return !run_grouped_conv_bwd_data_example(argc, argv); #else - return !run_grouped_conv_bwd_data_example(argc, argv); + return !run_grouped_conv_bwd_data_example(argc, argv); #endif } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp index 159d19fdcd..695adf01bc 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp @@ -14,7 +14,7 @@ #include "grouped_convolution_backward_weight_invoker.hpp" #include "run_grouped_convolution_bwd_weight_example.inc" -template