From 536315849b1c08732a42a3e460afe51cdabfc771 Mon Sep 17 00:00:00 2001
From: jakpiase
Date: Tue, 23 Dec 2025 10:03:42 +0100
Subject: [PATCH] [CK_TILE] Add splitk support to ck tile conv bwd data (#3353)

* add splitk support to ck tile conv bwd data

* add reviewers suggestions

* minor fix

* removed splitkbatchoffset struct

[ROCm/composable_kernel commit: ead81d1b0bba57b86ac28f3e2994dc97279f8eb3]
---
 ...ouped_convolution_backward_data_kernel.hpp | 75 +++++++++++++------
 ...ped_convolution_backward_weight_kernel.hpp | 37 ++-------
 2 files changed, 57 insertions(+), 55 deletions(-)

diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
index 6e1ac39509..ad445e17a7 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
@@ -542,9 +542,6 @@ struct GroupedConvolutionBackwardDataKernel
     static constexpr index_t MaxGroupedGemmGroupsNum =
         GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum;

-    // TODO: Enable this
-    static constexpr bool IsSplitKSupported = false;
-
     static constexpr auto I0 = number<0>();
     static constexpr auto I1 = number<1>();
     static constexpr auto I2 = number<2>();
@@ -623,9 +620,8 @@ struct GroupedConvolutionBackwardDataKernel
     CK_TILE_HOST static bool
     IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
     {
-        if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
-                      is_any_of::value) ||
-                     !IsSplitKSupported)
+        if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
+                     is_any_of::value)
         {
             if(kargs.k_batch != 1)
             {
@@ -772,8 +768,8 @@ struct GroupedConvolutionBackwardDataKernel
         }();

         const auto& c_tensor_view = [&]() {
-            return make_tensor_view(c_ptr,
-                                    kargs.c_grid_descs_m_n[group_id]);
+            return make_tensor_view(
+                c_ptr, kargs.c_grid_descs_m_n[group_id]);
         }();

         const auto& ds_tensor_view = generate_tuple(
@@ -837,7 +833,7 @@ struct GroupedConvolutionBackwardDataKernel
     CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
                                                    const index_t i_m,
                                                    const index_t i_n,
-                                                   const index_t i_k = 0)
+                                                   const index_t i_k)
     {
         const auto& a_pad_view = views.at(I0);
         const auto& b_pad_view = views.at(I1);
@@ -893,20 +889,24 @@ struct GroupedConvolutionBackwardDataKernel
                                       WeiDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GroupedConvBwdDataKernelArgsSpecialized& kargs,
+                                      const index_t splitted_k,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n,
+                                      const index_t block_idx_k,
                                       const index_t group_id)
     {
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
             a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
-
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
-        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
-            gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
+        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
+        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
+        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
+
+        auto gemm_tile_windows =
+            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);

         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -914,7 +914,7 @@ struct GroupedConvolutionBackwardDataKernel
         const auto& d_block_window = gemm_tile_windows.at(I2);

         const auto& c_block_tile = GemmPipeline{}.template operator()(
-            a_block_window, b_block_window, num_loop, smem_ptr_0);
+            a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);

         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I3);
@@ -945,8 +945,10 @@ struct GroupedConvolutionBackwardDataKernel
                                       void* __restrict__ smem_ptr_0,
                                       void* __restrict__ smem_ptr_1,
                                       const GroupedConvBwdDataKernelArgsSpecialized& kargs,
+                                      const index_t splitted_k,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n,
+                                      const index_t block_idx_k,
                                       const index_t group_id)
     {
         // Create Gemm tensor views, pad views and tile windows
@@ -954,18 +956,25 @@ struct GroupedConvolutionBackwardDataKernel
             MakeGemmTensorViews(
                 a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
-        const index_t num_loop = amd_wave_read_first_lane(
-            TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
+        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
+        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
+        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
+        auto gemm_tile_windows =
+            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);

         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
         const auto& b_block_window = gemm_tile_windows.at(I1);
         const auto& d_block_window = gemm_tile_windows.at(I2);

-        const auto& c_block_tile = GemmPipeline{}.template operator()(
-            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
+        const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
+                                                                      b_block_window,
+                                                                      num_loop,
+                                                                      has_hot_loop,
+                                                                      tail_num,
+                                                                      smem_ptr_0,
+                                                                      smem_ptr_1);

         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I3);
@@ -1031,9 +1040,17 @@ struct GroupedConvolutionBackwardDataKernel
             static_cast(kargs.input_batch_stride);

         // SplitK
-        // TODO: Implement SplitK support
-        // const index_t split_k_idx =
-        //     __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
+        const index_t split_k_idx =
+            __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);
+
+        const index_t gemm_k = kargs.a_grid_descs_m_k[group_id].get_length(I1);
+
+        constexpr auto K1   = TilePartitioner::KPerBlock;
+        const index_t K_t   = amd_wave_read_first_lane(kargs.k_batch * K1);
+        const index_t KRead = amd_wave_read_first_lane((gemm_k + K_t - 1) / K_t * K1);
+
+        const index_t i_k        = amd_wave_read_first_lane(split_k_idx * KRead);
+        const index_t splitted_k = amd_wave_read_first_lane(KRead);

         // options
         // conv_bwd_data = Out * Weight = In
@@ -1060,8 +1077,10 @@ struct GroupedConvolutionBackwardDataKernel
                     smem_ptr_0,
                     smem_ptr_1,
                     kargs,
+                    splitted_k,
                     i_m,
                     i_n,
+                    i_k,
                     group_id);
             }
         }
@@ -1071,7 +1090,17 @@ struct GroupedConvolutionBackwardDataKernel
                          GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                          is_any_of::value))
            {
-                RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
+                RunGemm(a_ptr,
+                        b_ptr,
+                        kargs.ds_ptr,
+                        c_ptr,
+                        smem_ptr_0,
+                        kargs,
+                        splitted_k,
+                        i_m,
+                        i_n,
+                        i_k,
+                        group_id);
            }
        }
    }
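
The split-K slicing introduced above can be illustrated in isolation. Below is a
minimal standalone sketch (plain C++, not part of the patch; gemm_k, k_batch and
KPerBlock are local stand-ins for kargs.a_grid_descs_m_k[group_id].get_length(I1),
kargs.k_batch and TilePartitioner::KPerBlock) of how each grid slice along
blockIdx.z receives a KPerBlock-aligned range of the GEMM K dimension:

// Standalone sketch of the split-K slicing (not part of the patch).
#include <cstdio>

int main()
{
    const int KPerBlock = 32;  // stand-in for TilePartitioner::KPerBlock
    const int k_batch   = 3;   // stand-in for kargs.k_batch
    const int gemm_k    = 200; // stand-in for the GEMM K length

    // One "round" of all k_batch slices covers K_t elements; rounding gemm_k
    // up to a multiple of K_t and dividing by k_batch gives the per-slice
    // length KRead, always a multiple of KPerBlock.
    const int K_t   = k_batch * KPerBlock;
    const int KRead = (gemm_k + K_t - 1) / K_t * KPerBlock;

    for(int split_k_idx = 0; split_k_idx < k_batch; ++split_k_idx)
    {
        const int i_k        = split_k_idx * KRead; // K offset of this slice's tile window
        const int splitted_k = KRead;               // K extent handed to the GEMM pipeline
        std::printf("slice %d: K range [%d, %d)\n", split_k_idx, i_k, i_k + splitted_k);
    }
    return 0;
}

With these values KRead = 96, so the last slice's range [192, 288) runs past
gemm_k = 200; the K padding applied by MakeGemmPadViews is presumably what keeps
those trailing reads benign. This is also why MakeGemmTileWindows loses its
i_k = 0 default above: every caller now passes an explicit per-slice K offset.
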
diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
index 1004ed81b1..6034dfc3de 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp
@@ -505,33 +505,6 @@ struct GroupedConvolutionBackwardWeightKernel
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

-    struct SplitKBatchOffset
-    {
-        __device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
-                                     const std::size_t k_id = blockIdx.z)
-        {
-            constexpr auto K1   = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
-            const index_t K_t   = amd_wave_read_first_lane(kargs.k_batch * K1);
-            const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
-
-            a_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
-            b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
-
-            if(k_id < static_cast(kargs.k_batch - 1))
-            {
-                splitted_k = amd_wave_read_first_lane(KRead);
-            }
-            else
-            {
-                splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
-            }
-        }
-
-        index_t a_k_split_offset;
-        index_t b_k_split_offset;
-        index_t splitted_k;
-    };
-
     CK_TILE_HOST static bool
     IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
     {
@@ -763,12 +736,12 @@ struct GroupedConvolutionBackwardWeightKernel
     }

     template <typename TensorView>
-    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
+    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
     {
         const auto& a_pad_view = [&]() {
             const auto& a_tensor_view = views.at(I0);
             return pad_tensor_view(a_tensor_view,
-                                   make_tuple(number{} * k_batch,
+                                   make_tuple(number{},
                                               number{}),
                                    sequence{});
         }();
@@ -776,7 +749,7 @@ struct GroupedConvolutionBackwardWeightKernel
         const auto& b_pad_view = [&]() {
             const auto& b_tensor_view = views.at(I1);
             return pad_tensor_view(b_tensor_view,
-                                   make_tuple(number{} * k_batch,
+                                   make_tuple(number{},
                                               number{}),
                                    sequence{});
         }();
@@ -882,7 +855,7 @@ struct GroupedConvolutionBackwardWeightKernel
             MakeGemmTensorViews(
                 a_ptr, b_ptr, ds_ptr, c_ptr, kargs);

-        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows =
             MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
@@ -932,7 +905,7 @@ struct GroupedConvolutionBackwardWeightKernel
         const auto& gemm_tensor_views_tuple =
             MakeGemmTensorViews(
                 a_ptr, b_ptr, ds_ptr, c_ptr, kargs);

-        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows =
             MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
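
For contrast with the uniform slicing used by the bwd-data kernel, the
SplitKBatchOffset struct removed above sized the final slice by the K remainder
instead of relying on padding. A minimal sketch of that rule (plain C++, not
part of the patch; K1 is a stand-in for the warp-tile K extent the struct read
from GemmPipeline::BlockGemmShape):

// Sketch of the remainder handling in the removed SplitKBatchOffset (not part
// of the patch): every slice but the last gets KRead, the last gets whatever
// is left of GemmK, so no K padding is needed.
#include <cstdio>

int splitted_k_of(int GemmK, int k_batch, int K1, int k_id)
{
    const int K_t   = k_batch * K1;
    const int KRead = (GemmK + K_t - 1) / K_t * K1;
    return k_id < k_batch - 1 ? KRead : GemmK - KRead * (k_batch - 1);
}

int main()
{
    // GemmK = 200, k_batch = 3, K1 = 32 -> KRead = 96 and slices of 96, 96, 8.
    for(int k_id = 0; k_id < 3; ++k_id)
        std::printf("slice %d: %d elements\n", k_id, splitted_k_of(200, 3, 32, k_id));
    return 0;
}

The patch drops that special case: the bwd-data path gives every slice the
uniform KRead, as sketched earlier, and the bwd-weight kernel now takes its
per-slice offsets at the call site, which is apparently why the * k_batch
factor could be removed from MakeGemmPadViews.
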