From 5e0d3e77b90f0c9766cdc2463f06554d8c5a079e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 8 Jan 2026 07:59:38 +0100 Subject: [PATCH] [CK TILE] Fix grouped conv kernels splitk and double lds (#3527) [ROCm/composable_kernel commit: bc497beffb1cb1036c995f50328b0535da3af159] --- ...ouped_convolution_backward_data_kernel.hpp | 138 +++-------------- ...ped_convolution_backward_weight_kernel.hpp | 105 ++----------- .../grouped_convolution_forward_kernel.hpp | 139 +++--------------- 3 files changed, 53 insertions(+), 329 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index 2e5f536ab7..a0ade4d318 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -1036,84 +1036,16 @@ struct GroupedConvolutionBackwardDataKernel } else { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, group_id, block_idx_m, block_idx_n); + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs Grouped Convolution Backward Data kernel arguments - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr, - const InDataType* b_ptr, - const std::array& ds_ptr, - WeiDataType* c_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const GroupedConvBwdDataKernelArgsSpecialized& kargs, - const index_t splitted_k, - const index_t block_idx_m, - const index_t block_idx_n, - const index_t block_idx_k, - const index_t group_id) - { - // Create block windows using specialized methods - const auto& a_block_window = - MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k); - const auto& b_block_window = - MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k); - const auto& d_block_window = - MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n); - - const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k)); - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - - // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - - const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - - // Run Epilogue Pipeline with k_batch dispatch - if(k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, group_id, block_idx_m, block_idx_n); - - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - else - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, group_id, block_idx_m, block_idx_n); - - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } } @@ -1195,46 +1127,18 @@ struct GroupedConvolutionBackwardDataKernel static_cast(kargs.in_ptr) + group_offset_c + input_batch_offset; // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitted_k, - i_m, - i_n, - i_k, - group_id); - } - } - else - { - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitted_k, - i_m, - i_n, - i_k, - group_id); - } - } + __shared__ char smem_ptr[GetSmemSize()]; + RunGemm(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr, + kargs, + splitted_k, + i_m, + i_n, + i_k, + group_id); } }; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 6bcd05e9ba..916f7a96ab 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -829,66 +829,14 @@ struct GroupedConvolutionBackwardWeightKernel } else { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs Grouped Convolution Backward Weight kernel arguments - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr, - const InDataType* b_ptr, - const std::array& ds_ptr, - WeiDataType* c_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const GroupedConvBwdWeightKernelArgsSpecialized& kargs, - const index_t num_loop, - const index_t block_idx_m, - const index_t block_idx_n, - const index_t block_idx_k) - { - // Create block windows using helper methods - const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k); - const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k); - const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - - // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - - // Run Epilogue Pipeline with k_batch dispatching - if(kargs.k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - else - { -#if defined(__gfx11__) - return; -#endif - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } } @@ -949,44 +897,9 @@ struct GroupedConvolutionBackwardWeightKernel const InDataType* b_ptr = static_cast(kargs.in_ptr) + group_offset_b; WeiDataType* c_ptr = static_cast(kargs.wei_ptr) + group_offset_c; - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - num_loop, - i_m, - i_n, - i_k); - } - } - else - { - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - c_ptr, - smem_ptr_0, - kargs, - num_loop, - i_m, - i_n, - i_k); - } - } + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k); } } }; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 1b81bce34a..4af8d8a768 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -954,80 +954,16 @@ struct GroupedConvolutionForwardKernel } else { - auto c_block_window = MakeCBlockWindow( - c_ptr, c_desc, block_idx_m, block_idx_n); + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, ds_block_window, smem_ptr_0); - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param ds_ptr input D tensors pointer array - * @param c_ptr output C pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param a_desc Input tensor A descriptor - * @param b_desc Weight tensor B descriptor - * @param c_desc Output tensor C descriptor - * @param gemm_k The GEMM K dimension - * @param k_batch The K batch parameter for split-K - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - template - CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr, - const WeiDataType* b_ptr, - const std::array& ds_ptr, - OutDataType* c_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const ADescType& a_desc, - const BDescType& b_desc, - const CDescType& c_desc, - const index_t gemm_k, - const index_t k_batch, - const index_t block_idx_m, - const index_t block_idx_n, - const CDElementwise& elfunc) - { - // Create block windows using specialized methods - const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m); - const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n); - const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n); - - const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - - // Run Epilogue Pipeline with k_batch dispatching - if(k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, c_desc, block_idx_m, block_idx_n); - - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, ds_block_window, smem_ptr_0); - } - else - { - auto c_block_window = MakeCBlockWindow( - c_ptr, c_desc, block_idx_m, block_idx_n); - - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } } @@ -1177,50 +1113,21 @@ struct GroupedConvolutionForwardKernel const auto& c_desc = kargs.c_grid_desc_m_n; // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - ds_ptr_with_offsets, - c_ptr, - smem_ptr_0, - smem_ptr_1, - a_desc, - b_desc, - c_desc, - kargs.GemmK, - kargs.k_batch, - i_m, - i_n, - kargs.elfunc); - } - } - else - { - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, - b_ptr, - ds_ptr_with_offsets, - c_ptr, - smem_ptr_0, - a_desc, - b_desc, - c_desc, - kargs.GemmK, - kargs.k_batch, - i_m, - i_n, - kargs.elfunc); - } - } + RunGemm(a_ptr, + b_ptr, + ds_ptr_with_offsets, + c_ptr, + smem_ptr, + a_desc, + b_desc, + c_desc, + kargs.GemmK, + kargs.k_batch, + i_m, + i_n, + kargs.elfunc); } } };