From b2019db49542cb72daa3a3f7fb9ea26bfae6deda Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 5 Dec 2025 18:14:37 +0000 Subject: [PATCH] Merge commit '6b1bceca7baea62941793e562d6ff58c571d9191' into develop --- .../gemm_bquant_quantgrouped_preshuffleb.cpp | 192 ++++++++++++++++-- .../run_gemm_quant_example.inc | 27 ++- include/ck_tile/host/tensor_shuffle_utils.hpp | 10 +- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 13 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 1 - .../test_gemm_quant_bquant_preshuffle.cpp | 44 +++- .../test_gemm_quant_fixtures.hpp | 6 +- 7 files changed, 257 insertions(+), 36 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp index 8ebf5bbd96..b32356c29d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp @@ -14,36 +14,154 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; lut[hash_multiple_strings( {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, @@ -52,10 +170,50 @@ void bquant_quantgrouped_preshuffleb_instance_factory( lut[hash_multiple_strings( {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 0ee19b4a26..8a0dd9bc08 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -140,6 +140,13 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::WPQuantBPipelineAgBgCrV2, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + constexpr bool TiledPermuteN = + (QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; + if(s.log_level_ > 0) + { + printf( + "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN); + } using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -382,7 +389,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, "K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode"); } } - ck_tile::index_t AQK, BQK; + ck_tile::index_t AQK, BQK, BQN = 0; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize @@ -392,6 +399,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) @@ -431,7 +439,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQ = 0; // No A quantization - stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout)); + stride_BQ = ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout)); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { @@ -471,7 +479,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { @@ -557,7 +565,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, b_k_n.SetZero(); bq_tensor_ptr->SetZero(); } - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); @@ -610,7 +617,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(GemmConfig::PreshuffleB) { - if constexpr(GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::TiledMMAPermuteN && QuantGroupSize::kN == 1) { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); @@ -635,11 +642,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) { - if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN && + QuantGroupSize::kN == 1) { - printf("Preshuffle BQ with TiledMMAPermuteN \n"); ck_tile::HostTensor bq_permuted_host = - ck_tile::bq_permuteN(*bq_tensor_ptr); + ck_tile::bq_permuteN(*bq_tensor_ptr, QuantGroupSize::kN); if constexpr(GemmConfig::PreshuffleQuant) { @@ -659,7 +666,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); } else + { bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); + } } invoke_gemm* t, int block_aq_k) } int m_ = t->get_lengths()[0]; int aqk_ = t->get_lengths()[1]; + if(aqk_ % block_aq_k != 0) { throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); @@ -110,7 +111,7 @@ auto shuffle_b(const ck_tile::HostTensor& t) } template -auto bq_permuteN(const ck_tile::HostTensor& t) +auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { assert(t.get_lengths().size() == 2); @@ -118,8 +119,11 @@ auto bq_permuteN(const ck_tile::HostTensor& t) int bqk_ = t.get_lengths()[0]; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - ck_tile::HostTensor t_view( - {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_}); + ck_tile::HostTensor t_view({n_ / (GemmConfig::N_Tile / group_n), + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile / group_n, + NRepeat, + bqk_}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index b54a93614a..58b713cb35 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg using QuantGroupSize = remove_cvref_t; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); - static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -205,7 +204,17 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg } else { - constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale; + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = cvt_scale_to_fp32(scale_reg); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index f6cf4ce9be..dd85705cf2 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -747,7 +747,6 @@ struct QuantGemmKernel (splitk_batch_offset.splitted_k / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; - return make_naive_tensor_view( b_ptr, make_tuple(kFlatN, kFlatK), diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 59b267842f..6cde4bded5 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -19,6 +19,12 @@ using PkInt4 = ck_tile::pk_int4_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + // Type combinations for BQuant tests with PreshuffleB // Tuple format: @@ -37,7 +43,43 @@ using BPreshuffleBQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + // //2d cases with preshuffle B + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 3b62d8073e..7b16529aa8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -433,7 +433,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; if constexpr(PreshuffleB) { - if constexpr(TiledMMAPermuteN) + if constexpr(TiledMMAPermuteN && QuantGroupSize::kN == 1) { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); @@ -451,11 +451,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase bq_shuffle_host = - ck_tile::bq_permuteN(bq_bqk_bqn); + ck_tile::bq_permuteN(bq_bqk_bqn, QuantGroupSize::kN); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } else if constexpr(GemmConfig::PreshuffleQuant)