From bafd8577ffdbb39f8582353e8a275b238c80a5e0 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Wed, 27 Aug 2025 18:19:30 +0000 Subject: [PATCH] Merge commit 'cd53e2e57ed9106b898defc5f610b167370f028f' into develop --- .../ck_tile/18_flatmm/run_flatmm_example.inc | 4 +- ...n_grouped_convolution_bwd_data_example.inc | 62 +++++++++---------- .../kernel/gemm_aquant_kernel.hpp | 18 +++--- 3 files changed, 41 insertions(+), 43 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 013db6715d..ff1a239cba 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -40,8 +40,8 @@ template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2) : (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc index 3e1c13c833..d1cf4fade7 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc @@ -11,17 +11,17 @@ template float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args, - int n_warmup, - int n_repeat) + int n_warmup, + int n_repeat) { float ave_time = grouped_conv_bwd_data( + InDataType, + WeiDataType, + AccDataType, + OutDataType, + InLayout, + WeiLayout, + OutLayout>( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = args.GetFlops(); @@ -124,11 +124,11 @@ int run_grouped_conv_bwd_data_example_with_layouts( output_dev_buf.ToDevice(output.data()); ck_tile::GroupedConvBwdDataHostArgs args(conv_param, - input_dev_buf.GetDeviceBuffer(), - weight_dev_buf.GetDeviceBuffer(), - {}, - output_dev_buf.GetDeviceBuffer(), - kbatch); + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl; std::cout << "input: " << input.mDesc << std::endl; @@ -136,13 +136,13 @@ int run_grouped_conv_bwd_data_example_with_layouts( std::cout << "output: " << output.mDesc << std::endl; invoke_grouped_conv_bwd_data(args, n_warmup, n_repeat); + InDataType, + WeiDataType, + AccDataType, + OutDataType, + InLayout, + WeiLayout, + OutLayout>(args, n_warmup, n_repeat); input_dev_buf.FromDevice(input.data()); bool pass = true; @@ -152,17 +152,15 @@ int run_grouped_conv_bwd_data_example_with_layouts( ck_tile::HostTensor input_host_ref(in_g_n_c_wis_desc); input_host_ref.SetZero(); - ck_tile:: - reference_grouped_conv_bwd_data( - input_host_ref, - weight, - output, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_); - const ck_tile::index_t GemmK = - weight.get_element_size() / (conv_param.G_ * conv_param.K_); + ck_tile::reference_grouped_conv_bwd_data( + input_host_ref, + weight, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); const float max_accumulated_value = *std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end()); const auto rtol_atol = diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp index 49fbbfbc71..69acb668ed 100644 --- a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp +++ b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp @@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs template struct AQuantGemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using AQLayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using AQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; + static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant; using ADataType = remove_cvref_t; using AQDataType = remove_cvref_t;