From ce4525e82bebae411a5047df489ddb6de24b474e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 21 Mar 2026 23:55:24 +0100 Subject: [PATCH] [CK][CK Tile] Fix kbatch check in grouped conv and gemm kernels (#5555) ## Motivation Fix the kbatch check in the grouped conv and gemm kernels so that K is allowed to have a tail when split across kbatch. ## Technical Details Round up K / KPerXdl and compare the result against KBatch, so a K tail no longer rejects otherwise-valid kbatch values. ## Test Plan test_grouped_convnd_bwd_weight_tile ## Test Result passed locally ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp | 3 ++- .../kernel/grouped_convolution_backward_weight_kernel.hpp | 4 +++- .../grouped_convolution_backward_weight_tile_algs.hpp | 6 +++--- .../grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp | 4 ++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 1dd467f1c8..37ed8ce49a 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -418,7 +418,8 @@ struct UniversalGemmKernel } } - if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch) + if(integer_divide_ceil(kargs.K, GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})) < + kargs.k_batch) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 39c7ba1370..5df84be0c9 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ 
b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -574,7 +574,9 @@ struct GroupedConvolutionBackwardWeightKernel } } - if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch) + if(integer_divide_ceil(kargs.GemmK, + TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) < + kargs.k_batch) { LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ", kargs.GemmK, diff --git a/profiler/include/profiler/grouped_convolution_backward_weight_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_backward_weight_tile_algs.hpp index f69c5bb7a1..fb51adb4a7 100644 --- a/profiler/include/profiler/grouped_convolution_backward_weight_tile_algs.hpp +++ b/profiler/include/profiler/grouped_convolution_backward_weight_tile_algs.hpp @@ -178,11 +178,11 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args& args, }); const bool valid = report.get_errors().empty(); + best_avg_time = std::min(best_avg_time, avg_time); + best_op_name = best_avg_time < avg_time ? best_op_name : op_name; + best_split_k = best_avg_time < avg_time ? best_split_k : k_batch; if(valid) { - best_avg_time = std::min(best_avg_time, avg_time); - best_op_name = best_avg_time < avg_time ? best_op_name : op_name; - best_split_k = best_avg_time < avg_time ? 
best_split_k : k_batch; std::cout << "[Valid] Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name << ", SplitK " << k_batch << std::endl; } diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp index 237641a000..4ea3479db0 100644 --- a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp @@ -219,12 +219,12 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation) tensor_layout::convolution::NHWGK>::type; // k_batch = 128 should pass - auto host_args_kbatch_6 = create_2d_host_args(6); + auto host_args_kbatch_6 = create_2d_host_args(7); auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6)); // k_batch = 129 should fail for half_t output - auto host_args_kbatch_7 = create_2d_host_args(7); + auto host_args_kbatch_7 = create_2d_host_args(8); auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7); EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7)); }