diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 1dd467f1c8..37ed8ce49a 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -418,7 +418,8 @@ struct UniversalGemmKernel } } - if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch) + if(integer_divide_ceil(kargs.K, GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})) < + kargs.k_batch) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 39c7ba1370..5df84be0c9 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -574,7 +574,9 @@ struct GroupedConvolutionBackwardWeightKernel } } - if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch) + if(integer_divide_ceil(kargs.GemmK, + TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) < + kargs.k_batch) { LogInfo("KBatch is too large, part of GPU wouldn't be utilized! 
GemmK: ", kargs.GemmK, diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp index 237641a000..4ea3479db0 100644 --- a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp @@ -219,12 +219,12 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation) tensor_layout::convolution::NHWGK>::type; - // k_batch = 128 should pass - auto host_args_kbatch_6 = create_2d_host_args(6); + // k_batch = 7 should pass + auto host_args_kbatch_6 = create_2d_host_args(7); auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6); EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6)); - // k_batch = 129 should fail for half_t output - auto host_args_kbatch_7 = create_2d_host_args(7); + // k_batch = 8 should fail for half_t output + auto host_args_kbatch_7 = create_2d_host_args(8); auto kargs_7 = typename 
Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7); EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7)); }