From 679699f32ad2a018184443647a33caaac4c2d348 Mon Sep 17 00:00:00 2001
From: Johannes Graner
Date: Mon, 24 Nov 2025 07:36:26 +0100
Subject: [PATCH] [CK Tile] Fix example for conv fwd + bias + clamp (#3235)

* Fix clamp not being applied correctly

* Apply group offsets to D tensors

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: 096f0a3b23a49ffaef1e2dbed74bf366e36ad15c]
---
 ...ped_convolution_fwd_bias_clamp_example.inc | 27 ++++++--------
 .../grouped_convolution_forward_kernel.hpp    | 36 +++++++++++++------
 2 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_bias_clamp_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_bias_clamp_example.inc
index fe3c1791d6..3065bc1e94 100644
--- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_bias_clamp_example.inc
+++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_bias_clamp_example.inc
@@ -184,11 +184,6 @@ int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
 
     if(arg_parser.get_int("v") == 1)
     {
-        // FIXME: Address this issue
-        if(arg_parser.get_int("g") > 1 && init_method == 0)
-            std::cerr << "Adding different bias to different groups yield incorrect results"
-                      << std::endl;
-
         ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
         output_host_ref.SetZero();
 
@@ -250,29 +245,27 @@ template <typename GemmWarpConfig,
     if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
     {
-        // return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
-        //                                                             GemmWarpConfig,
-        //                                                             Invoker,
-        //                                                             InPrecType,
-        //                                                             WeiPrecType,
-        //                                                             OutPrecType>(
-        //     argc, argv, NWGC{}, GKXC{}, NWGK{});
+        return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
+                                                                    GemmWarpConfig,
+                                                                    Invoker,
+                                                                    InPrecType,
+                                                                    WeiPrecType,
+                                                                    OutPrecType>(
+            argc, argv, NWGC{}, GKXC{}, NWGK{});
     }
     else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
     {
diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
index 4eccd1eebb..a07ba1b05d 100644
--- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
+++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
@@ -5,9 +5,11 @@
 
 #include <iostream>
 #include <string>
+#include <array>
 
 #include "ck_tile/core.hpp"
 #include "ck_tile/core/tensor/tile_elementwise.hpp"
+#include "ck_tile/core/utility/functional.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/host/concat.hpp"
 #include "ck_tile/core/utility/env.hpp"
@@ -884,7 +886,8 @@ struct GroupedConvolutionForwardKernel
                                        const CDescType& c_desc,
                                        const index_t gemm_k,
                                        const index_t block_idx_m,
-                                       const index_t block_idx_n)
+                                       const index_t block_idx_n,
+                                       const CDElementwise& elfunc)
     {
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
@@ -907,8 +910,9 @@ struct GroupedConvolutionForwardKernel
 
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I3);
-        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
-            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
+        EpiloguePipeline{elfunc}
+            .template operator()<decltype(c_block_window), decltype(c_block_tile)>(
+                c_block_window, c_block_tile, d_block_window, smem_ptr_0);
     }
 
     /**
@@ -942,7 +946,8 @@ struct GroupedConvolutionForwardKernel
                                            const CDescType& c_desc,
                                            const index_t gemm_k,
                                            const index_t block_idx_m,
-                                           const index_t block_idx_n)
+                                           const index_t block_idx_n,
+                                           const CDElementwise& elfunc)
     {
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
@@ -964,8 +969,9 @@ struct GroupedConvolutionForwardKernel
 
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I3);
-        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
-            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
+        EpiloguePipeline{elfunc}
+            .template operator()<decltype(c_block_window), decltype(c_block_tile)>(
+                c_block_window, c_block_tile, d_block_window, smem_ptr_0);
     }
 
     CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
@@ -998,6 +1004,14 @@ struct GroupedConvolutionForwardKernel
         OutDataType* base_c_ptr = static_cast<OutDataType*>(kargs.out_ptr) +
                                   group_offset_c + output_batch_offset;
 
+        // Apply group offsets to D tensors
+        std::array<const void*, NumDTensor> ds_ptr_with_offsets;
+        static_for<0, NumDTensor, 1>{}([&](auto d) {
+            using DType = std::tuple_element_t<d.value, DsDataType>;
+            ds_ptr_with_offsets[d] =
+                static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
+        });
+
         // =====================================================================
         // Split-image: Map local block to global tile index (if enabled)
         // =====================================================================
@@ -1085,7 +1099,7 @@ struct GroupedConvolutionForwardKernel
             {
                 RunGemm2LDS(a_ptr,
                             b_ptr,
-                            kargs.ds_ptr,
+                            ds_ptr_with_offsets,
                             c_ptr,
                             smem_ptr_0,
                             smem_ptr_1,
@@ -1094,7 +1108,8 @@ struct GroupedConvolutionForwardKernel
                             c_desc,
                             kargs.GemmK,
                             i_m,
-                            i_n);
+                            i_n,
+                            kargs.elfunc);
             }
         }
         else
@@ -1105,7 +1120,7 @@ struct GroupedConvolutionForwardKernel
             {
                 RunGemm(a_ptr,
                         b_ptr,
-                        kargs.ds_ptr,
+                        ds_ptr_with_offsets,
                         c_ptr,
                         smem_ptr_0,
                         a_desc,
@@ -1113,7 +1128,8 @@ struct GroupedConvolutionForwardKernel
                         c_desc,
                         kargs.GemmK,
                         i_m,
-                        i_n);
+                        i_n,
+                        kargs.elfunc);
             }
         }
     }
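
Note: conceptually, this patch fixes two separate bugs. First, the epilogue was
default-constructed (EpiloguePipeline{}), which dropped the user-supplied
elementwise functor, so the clamp was silently never applied. Second, the D
(bias) pointers were passed without the per-group/per-batch offset that the C
pointer already received, so every group read group 0's bias. Below is a
minimal standalone host-side C++ sketch of both behaviors; Clamp, Epilogue,
and the flat per-group layout are illustrative stand-ins, not the CK Tile API.

// Sketch of the two fixes: (1) build the epilogue WITH the clamp functor,
// (2) offset each group's bias pointer the same way as the output pointer.
#include <algorithm>
#include <cstdio>
#include <vector>

struct Clamp {
    float lo, hi;
    float operator()(float x) const { return std::min(std::max(x, lo), hi); }
};

// Stand-in for EpiloguePipeline{elfunc}: stores the functor it is built with,
// mirroring how the patch passes kargs.elfunc down into the epilogue.
struct Epilogue {
    Clamp elfunc;
    void operator()(float* c, const float* bias, int n) const {
        for(int i = 0; i < n; ++i)
            c[i] = elfunc(c[i] + bias[i]); // bias add, then clamp
    }
};

int main() {
    constexpr int G = 2, N = 4;       // groups, elements per group
    std::vector<float> c(G * N, 5.f); // pretend GEMM output for all groups
    std::vector<float> bias = {0.f, 1.f, 2.f, 3.f,      // group 0
                               10.f, 11.f, 12.f, 13.f}; // group 1

    Epilogue epilogue{Clamp{0.f, 8.f}}; // functor passed in, as in the fix
    for(int g = 0; g < G; ++g) {
        const int group_offset = g * N; // same offset for C and D tensors,
                                        // mirroring ds_ptr_with_offsets
        epilogue(c.data() + group_offset, bias.data() + group_offset, N);
    }
    for(float v : c) std::printf("%g ", v); // prints: 5 6 7 8 8 8 8 8
    std::printf("\n");
}

With the pre-fix behavior (a default-constructed epilogue, or the bias pointer
left unoffset), group 1 would come out unclamped or with group 0's bias, which
is exactly the mismatch the example's verification step was hitting.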