From 2e1831f8fd134f59bdd7790079adbc53871fcbd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 31 Oct 2025 14:11:54 +0100 Subject: [PATCH] [CK TILE] Clear output buffers for grouped conv bwd (#3127) [ROCm/composable_kernel commit: c2d79314469f569c13c205ff5383f284c90d7445] --- ...rouped_convolution_backward_data_invoker.hpp | 11 +++++++++-- ...uped_convolution_backward_weight_invoker.hpp | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index d8a6564f46..f6d20c3d3a 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -170,8 +170,15 @@ struct GroupedConvolutionBackwardDataInvoker << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); + }; + + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index becc11ff54..0c00bb78e1 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -171,8 +171,21 @@ struct GroupedConvolutionBackwardWeightInvoker << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto preprocess = [&]() { + if(args.k_batch > 1) + { + ck_tile::hip_check_error( + hipMemsetAsync(kargs.wei_ptr, + 0, + args.template GetWeightByte(), + s.stream_id_)); + } + }; + + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; };