From 77407b3d261856cc7e8496dbe3a0c071f2fb8f29 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 28 Nov 2025 12:33:53 +0000 Subject: [PATCH] [CK_TILE] Fix Quant GEMM build (#3320) * Fix build * Fix ck_tile example 38 & 40 --------- Co-authored-by: Yi DING [ROCm/composable_kernel commit: f981554c39eafbf993e05c832cb86b3aaf474571] --- .../38_block_scale_gemm/gemm_utils.hpp | 34 ------------------- .../run_gemm_quant_example.inc | 7 ++-- .../40_streamk_gemm/run_gemm_example.inc | 3 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 2 ++ .../kernel/grouped_gemm_quant_kernel.hpp | 1 + .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 1 + .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 4 +-- ...p_bquant_pipeline_ag_bg_cr_base_policy.hpp | 1 + .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 1 + 9 files changed, 13 insertions(+), 41 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index bd9f93ce23..81032d6452 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -280,37 +280,3 @@ struct GemmQuantTypeConfig using AccDataType = float; using CDataType = CDataType_; }; - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") - .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "C", "B tensor data layout - Column by default") - .insert("bq_layout", "C", "Bq tensor data layout - Column by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_q", "0", "Tensor AQ stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", - "fp8", - "data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "1000", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value") - .insert("init", "0", "0:random, 1:linear, 2:constant(1)") - .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") - .insert("rotating_count", "1000", "rotating count, defaults to 1") - .insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol") - .insert("group_size", - "1x1x128", - "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index c9a57e7754..44d0736ad3 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -15,6 +15,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "ck_tile/ops/gemm_quant.hpp" #include "gemm_utils.hpp" template , std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, - ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3, + ck_tile::BaseGemmPipelineAgBgCrCompV3, std::conditional_t, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>>>; + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index ebb5140e50..d18ac2e68a 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -81,8 +81,7 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, K, stride_A, stride_B, - stride_C, - reduction_strategy}; + stride_C}; std::tuple ave_time_and_batch; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index a0cfd4bf53..c7c161e710 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index c80d5d5267..caa6aad363 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 8d76ab934b..f3c8b7a1a3 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 8f4d4e0460..4883a30f57 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -272,7 +272,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template make_shuffled_2d_static_tile_distribution()); + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -284,7 +284,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template make_shuffled_2d_static_tile_distribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp index e90fa42407..28a06f8b3d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index d7129268c5..59a5b0df4e 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp"