diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 59ff086dca..fd02e90594 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -29,6 +29,7 @@ template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, @@ -75,7 +76,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, AccDataType, GemmShape, GemmUniversalTraits, - 128>, // QuantGroupSize + QuantGroupSize>, // QuantGroupSize ck_tile::GemmRowColTensorQuantPipelineProblem float invoke_gemm(int n_warmup, @@ -104,6 +105,7 @@ float invoke_gemm(int n_warmup, BQDataType, AccDataType, CDataType, + QuantGroupSize, QuantMode>(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -134,6 +136,7 @@ template (group_count)) && ...); }; - const int group_count = arg_parser.get_int("group_count"); - const int repeat = arg_parser.get_int("repeat"); - const int warmup = arg_parser.get_int("warmup"); - const int kbatch = arg_parser.get_int("kbatch"); - const int init_method = arg_parser.get_int("init"); - bool validate = arg_parser.get_bool("validate"); - const ck_tile::index_t QuantGroupSize = 128; + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + const int init_method = arg_parser.get_int("init"); + bool validate = arg_parser.get_bool("validate"); if(kbatch > 1 && validate && warmup + repeat > 1) { @@ -259,9 +261,9 @@ int run_grouped_gemm_example_with_layouts(int argc, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize - if(K % QuantGroupSize != 0) + AQK = 0; // No A quantization + BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize::kK != 0) { throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); } @@ -400,6 +402,7 @@ int run_grouped_gemm_example_with_layouts(int argc, BLayout, BQLayout, CLayout, + QuantGroupSize, QuantMode>(warmup, repeat, group_count, gemm_descs); for(int i = 0; i < group_count; i++) @@ -481,12 +484,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Types = GemmTypeConfig; // Specific type aliases for easy access - using ADataType = typename Types::ADataType; - using BDataType = typename Types::BDataType; - using AccDataType = typename Types::AccDataType; - using CDataType = typename Types::CDataType; - using AQDataType = typename Types::AccDataType; - using BQDataType = typename Types::AccDataType; + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + using AQDataType = typename Types::AccDataType; + using BQDataType = typename Types::AccDataType; + using QuantGroupSize = ck_tile::QuantGroupShape>; + if(a_layout == "R" && b_layout == "C") { return run_grouped_gemm_example_with_layouts( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 38d76410a9..8eb1ee4496 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -390,6 +390,7 @@ struct QuantGroupedGemmKernel const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, bq_block_window, + kargs.N, num_loop, tail_num, smem_ptr_0,