From 7d6cd1f3c4dc7c2d28d4856d64cbfe4a39806fa3 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Mon, 24 Nov 2025 07:48:42 -0800 Subject: [PATCH] [CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165) * formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * addressing review comments * fixing CI issue * addressing reveiw comments * formatting * formatting * fixing aquant operator overlaoding * formatting --------- Co-authored-by: Cong Ma Co-authored-by: Thomas Ning [ROCm/composable_kernel commit: 8111572785d3de98457940f2b5ca6fe9cf7603af] --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 5 +- .../quant_run_grouped_gemm_example.inc | 38 ++-- .../38_block_scale_gemm/CMakeLists.txt | 13 +- example/ck_tile/38_block_scale_gemm/README.md | 71 +++---- .../gemm_aquant_quantgrouped.cpp | 16 +- ...m_aquant_quantgrouped_preshufflequant.cpp} | 30 +-- ...8.cpp => gemm_bquant_quantgrouped_bf8.cpp} | 15 +- ...cpp => gemm_bquant_quantgrouped_bf8i4.cpp} | 15 +- ...8.cpp => gemm_bquant_quantgrouped_fp8.cpp} | 15 +- ...cpp => gemm_bquant_quantgrouped_fp8i4.cpp} | 15 +- .../gemm_bquant_quantgrouped_preshuffleb.cpp | 59 ++++++ ...antgrouped_preshuffleb_preshufflequant.cpp | 57 ++++++ ...mm_bquant_quantgrouped_preshufflequant.cpp | 59 ++++++ .../38_block_scale_gemm/gemm_quant.cpp | 26 ++- 
.../38_block_scale_gemm/gemm_quant_rowcol.cpp | 2 +- .../38_block_scale_gemm/gemm_quant_tensor.cpp | 2 +- .../38_block_scale_gemm/gemm_utils.hpp | 28 ++- .../run_gemm_quant_example.inc | 19 +- include/ck_tile/host/tensor_shuffle_utils.hpp | 45 ++++- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 49 ++++- .../block_universal_gemm_as_aquant_bs_cr.hpp | 3 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 76 ++++++-- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 180 +++++++++++++++--- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 5 +- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 4 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 60 ++++-- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 22 ++- .../pipeline/gemm_group_quant_utils.hpp | 102 ++++++---- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 51 ++++- .../test_gemm_quant_fixtures.hpp | 13 +- .../test_gemm_quant_typed.cpp | 7 +- 31 files changed, 855 insertions(+), 247 deletions(-) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_preshuffleb_prefill.cpp => gemm_aquant_quantgrouped_preshufflequant.cpp} (74%) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_prefill_bf8.cpp => gemm_bquant_quantgrouped_bf8.cpp} (76%) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_prefill_bf8i4.cpp => gemm_bquant_quantgrouped_bf8i4.cpp} (77%) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_prefill_fp8.cpp => gemm_bquant_quantgrouped_fp8.cpp} (76%) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_prefill_fp8i4.cpp => gemm_bquant_quantgrouped_fp8i4.cpp} (77%) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp 
b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 59ff086dca..38f7544d97 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -29,6 +29,7 @@ template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, @@ -49,7 +50,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::kPadN, GemmConfig::kPadK, false, // PreshuffleQuant - GemmConfig::PreshuffleB, // PreshuffleB + GemmConfig::PreshuffleB, ALayout, BLayout, CLayout, @@ -75,7 +76,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, AccDataType, GemmShape, GemmUniversalTraits, - 128>, // QuantGroupSize + QuantGroupSize>, ck_tile::GemmRowColTensorQuantPipelineProblem float invoke_gemm(int n_warmup, @@ -104,6 +105,7 @@ float invoke_gemm(int n_warmup, BQDataType, AccDataType, CDataType, + QuantGroupSize, QuantMode>(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -134,6 +136,7 @@ template (group_count)) && ...); }; - const int group_count = arg_parser.get_int("group_count"); - const int repeat = arg_parser.get_int("repeat"); - const int warmup = arg_parser.get_int("warmup"); - const int kbatch = arg_parser.get_int("kbatch"); - const int init_method = arg_parser.get_int("init"); - bool validate = arg_parser.get_bool("validate"); - const ck_tile::index_t QuantGroupSize = 128; + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + const int init_method = arg_parser.get_int("init"); + bool validate = arg_parser.get_bool("validate"); if(kbatch > 1 && validate && warmup + repeat > 1) { @@ -259,9 +261,9 @@ int run_grouped_gemm_example_with_layouts(int argc, } else if constexpr(QuantMode == 
ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize - if(K % QuantGroupSize != 0) + AQK = 0; // No A quantization + BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize::kK != 0) { throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); } @@ -400,6 +402,7 @@ int run_grouped_gemm_example_with_layouts(int argc, BLayout, BQLayout, CLayout, + QuantGroupSize, QuantMode>(warmup, repeat, group_count, gemm_descs); for(int i = 0; i < group_count; i++) @@ -481,12 +484,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Types = GemmTypeConfig; // Specific type aliases for easy access - using ADataType = typename Types::ADataType; - using BDataType = typename Types::BDataType; - using AccDataType = typename Types::AccDataType; - using CDataType = typename Types::CDataType; - using AQDataType = typename Types::AccDataType; - using BQDataType = typename Types::AccDataType; + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + using AQDataType = typename Types::AccDataType; + using BQDataType = typename Types::AccDataType; + using QuantGroupSize = ck_tile::QuantGroupShape>; + if(a_layout == "R" && b_layout == "C") { return run_grouped_gemm_example_with_layouts( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 932acb72fd..40a4166126 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -10,11 +10,14 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(${EXE_NAME} EXCLUDE_FROM_ALL gemm_quant.cpp 
gemm_aquant_quantgrouped.cpp - gemm_bquant_quantgrouped_prefill_bf8i4.cpp - gemm_bquant_quantgrouped_prefill_fp8i4.cpp - gemm_bquant_quantgrouped_prefill_bf8.cpp - gemm_bquant_quantgrouped_prefill_fp8.cpp - gemm_bquant_quantgrouped_preshuffleb_prefill.cpp + gemm_aquant_quantgrouped_preshufflequant.cpp + gemm_bquant_quantgrouped_bf8i4.cpp + gemm_bquant_quantgrouped_fp8i4.cpp + gemm_bquant_quantgrouped_bf8.cpp + gemm_bquant_quantgrouped_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb.cpp + gemm_bquant_quantgrouped_preshufflequant.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp gemm_quant_rowcol.cpp gemm_quant_tensor.cpp ) diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 64ecebd15a..b81c5de7ab 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -33,47 +33,50 @@ mkdir build && cd build # you can replace with the appropriate architecture (for example gfx942) or leave it blank ../script/cmake-ck-dev.sh ../ # Compile the quant kernels -make tile_example_gemm_quant_basic -j +make tile_example_gemm_quant -j ``` -This will result in an executable `build/bin/tile_example_gemm_quant_basic` +This will result in an executable `build/bin/tile_example_gemm_quant` ## example ``` args: - -h Print help message (default:false) - -m m dimension (default:3840) - -n n dimension (default:4096) - -k k dimension (default:2048) - -a_layout A tensor data layout - Row or Column (default:R) - -b_layout B tensor data layout - Row or Column (default:C) - -bq_layout Bq tensor data layout - Row or Column (default:C) - -c_layout C tensor data layout - Row or Column (default:R) - -stride_a Tensor A stride (default:0) - -stride_q Tensor AQ stride (default:0) - -stride_b Tensor B stride (default:0) - -stride_c Tensor C stride (default:0) - -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) - -warmup Number of iterations before benchmarking the kernel (default:50) - -repeat Number of iterations to benchmark the kernel (default:1000) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -split_k SplitK value (default:1) - -device Device id that will be used to run the kernel (default:0) - -init 0:random, 1:linear, 2:constant(1) (default:0) - -flush_cache Flush cache before running the kernel (default:true) --rotating_count Rotating count (default:1000) - -quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant) - -preshuffleb Enable preshuffle of tensor B (default:false) - -group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128) + -h Print help message (default:false) + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -a_layout A tensor data layout - R for Row or C for Column (default:R) + -b_layout B tensor data layout - R for Row or C for Column (default:C) + -bq_layout Bq tensor data layout - R for Row or C for Column (default:C) + -c_layout C tensor data layout - R for Row or C for Column (default:R) + -stride_a Tensor A stride (default:0) + -stride_q Tensor AQ stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) + -prec Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -warmup Number of iterations before benchmarking the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:1000) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k SplitK value (default:1) + -device Device id that will be used to run the kernel (default:0) + -init 0:random, 1:linear, 2:constant(1) (default:0) + -flush_cache Flush cache before running the kernel (default:true) + -rotating_count Rotating count (default:1000) + -quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant) + -preshuffleb Enable preshuffle of tensor B (default:false) + -preshufflequant Enable preshuffle of quant tensor (default:false) + -group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128) ``` User need to select correct mapping of config for each quant mode: -| | quant_mode as runtime argument | Config in cpp file | -|:--------|:-----:|-------| -| For selecting AQuant | aquant | GemmConfigQuant | -| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant | -| For selecting BQuant | bquant | GemmConfigQuant | -| For selecting PreShuffle Weight matrix with Bquant | bquant | GemmConfigPreshuffleB_Bquant_decode (or) GemmConfigPreshuffleB_Bquant_prefill -| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant | +| | quant_mode as runtime argument | Corresponding cpp file | GemmConfig at the top of cpp file | +|:--------|:-----:|:-----:|-------| +| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode | +| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode | +| For selecting BQuant | bquant | gemm_bquant_quantgrouped_.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill | +| For selecting BQuant with Preshuffle quant | bquant | 
gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill | +| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill +| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill +| For selecting RowCol quant | rowcolquant | gemm_quant_rowcol| GemmConfigRowColQuant | diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index 3786230ff0..793528f5bb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -4,14 +4,15 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuant; +using GemmConfig = GemmConfigQuantDecode; void aquant_quantgrouped_instance_factory( std::unordered_map>& lut) { using QuantGroupSize = ck_tile::QuantGroupShape>; - lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& - arg_parser) { + lut[hash_multiple_strings( + {"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, @@ -19,8 +20,9 @@ void aquant_quantgrouped_instance_factory( QuantGroupSize, ck_tile::QuantType::AQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& - arg_parser) { + lut[hash_multiple_strings( + {"bf8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, @@ -28,7 +30,7 @@ void 
aquant_quantgrouped_instance_factory( QuantGroupSize, ck_tile::QuantType::AQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] = + lut[hash_multiple_strings({"fp8i4", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using TypeConfig = decltype(GemmQuantTypeConfig(arg_parser); }; - lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] = + lut[hash_multiple_strings({"bf8i4", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using TypeConfig = decltype(GemmQuantTypeConfig -using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; +using GemmConfig = GemmConfigPreshuffleQuantDecode; -void bquant_quantgrouped_preshuffleb_instance_factory( +void aquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) { using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings( - {"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + {"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); + ck_tile::QuantType::AQuantGrouped>(arg_parser); }; lut[hash_multiple_strings( - {"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + {"bf8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); + ck_tile::QuantType::AQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] = + lut[hash_multiple_strings({"fp8i4", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using 
TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); + ck_tile::QuantType::AQuantGrouped>(arg_parser); }; - lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] = + lut[hash_multiple_strings({"bf8i4", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); + ck_tile::QuantType::AQuantGrouped>(arg_parser); }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp similarity index 76% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index cb9f8b62cf..2c2e55a730 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -18,28 +18,33 @@ void bquant_quantgrouped_bf8_instance_factory( using TypeConfig = decltype(GemmQuantTypeConfig{}); #ifndef CK_GFX950_SUPPORT - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] = + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; #endif - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] = + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - 
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] = + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] = + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] = + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp similarity index 77% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index 33ae3bc4a9..06d02d70c5 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -20,28 +20,33 @@ void bquant_quantgrouped_bf8i4_instance_factory( ck_tile::half_t, ck_tile::bf8_t>{}); #ifndef CK_GFX950_SUPPORT - lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; #endif - 
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp similarity index 76% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index 526c35b081..e4ff554267 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -18,28 +18,33 @@ void bquant_quantgrouped_fp8_instance_factory( using TypeConfig = decltype(GemmQuantTypeConfig{}); #ifndef CK_GFX950_SUPPORT - 
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] = + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; #endif - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] = + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] = + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] = + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] = + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp similarity index 77% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 4b2a8efb14..8839101c42 100644 --- 
a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -20,28 +20,33 @@ void bquant_quantgrouped_fp8i4_instance_factory( ck_tile::half_t, ck_tile::fp8_t>{}); #ifndef CK_GFX950_SUPPORT - lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; #endif - lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x8x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; - lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp 
b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp new file mode 100644 index 0000000000..898316fa6b --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; + +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git 
a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp new file mode 100644 index 0000000000..e26edb6501 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; + +void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + 
ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp new file mode 100644 index 0000000000..82967d5be2 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +void bquant_quantgrouped_preshufflequant_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, 
+ ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index a35f867f5d..216e3cdfb8 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -43,6 +43,7 @@ auto create_args(int argc, char* argv[]) .insert("rotating_count", "1000", "Rotating count") .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") .insert("preshuffleb", "false", "Enable preshuffle of tensor B") + .insert("preshufflequant", "false", "Enable preshuffle of quant tensor") .insert("group_size", "1x1x128", "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); @@ -58,11 +59,21 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser) std::vector params = {data_type, quant_mode}; + if(quant_mode == "aquant") + { + std::string preshufflequant = + arg_parser.get_bool("preshufflequant") ? "preshufflequant" : "non-preshufflequant"; + params.push_back(preshufflequant); + } if(quant_mode == "bquant") { std::string preshuffleb = arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb"; params.push_back(preshuffleb); + + std::string preshufflequant = + arg_parser.get_bool("preshufflequant") ? 
"preshufflequant" : "non-preshufflequant"; + params.push_back(preshufflequant); } if(quant_mode != "rowcol" && quant_mode != "tensor") { @@ -76,6 +87,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser) void aquant_quantgrouped_instance_factory( std::unordered_map>& lut); +void aquant_quantgrouped_preshufflequant_instance_factory( + std::unordered_map>& lut); void bquant_quantgrouped_fp8_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_bf8_instance_factory( @@ -86,6 +99,10 @@ void bquant_quantgrouped_bf8i4_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( + std::unordered_map>& lut); void quant_rowcol_instance_factory( std::unordered_map>& lut); void quant_tensor_instance_factory( @@ -106,11 +123,14 @@ int main(int argc, char* argv[]) std::unordered_map> lut; aquant_quantgrouped_instance_factory(lut); + aquant_quantgrouped_preshufflequant_instance_factory(lut); bquant_quantgrouped_fp8_instance_factory(lut); bquant_quantgrouped_bf8_instance_factory(lut); bquant_quantgrouped_fp8i4_instance_factory(lut); bquant_quantgrouped_bf8i4_instance_factory(lut); bquant_quantgrouped_preshuffleb_instance_factory(lut); + bquant_quantgrouped_preshufflequant_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); quant_rowcol_instance_factory(lut); quant_tensor_instance_factory(lut); @@ -122,9 +142,9 @@ int main(int argc, char* argv[]) } else { - std::cerr - << "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported." - << std::endl; + std::cerr << "Error: Combination of prec, quant_mode, preshuffleb, preshufflequant, and " + "group_size not supported." 
+ << std::endl; return -1; } } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp index 2d9e4e2c6d..648ca27c78 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuant; +using GemmConfig = GemmConfigQuantDecode; void quant_rowcol_instance_factory( std::unordered_map>& lut) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp index 21207373a7..c660c742f7 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuant; +using GemmConfig = GemmConfigQuantDecode; void quant_tensor_instance_factory( std::unordered_map>& lut) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index cf120e1dd0..fcc5c00327 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -110,7 +110,7 @@ struct GemmConfigBase }; template -struct GemmConfigQuant : public GemmConfigBase +struct GemmConfigQuantDecode : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; @@ -142,7 +142,7 @@ struct GemmConfigRowColQuant : public GemmConfigBase }; template -struct GemmConfigPreshuffleQuant : public GemmConfigBase +struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; @@ -161,7 +161,7 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase }; template -struct GemmConfigPreshuffleB_Bquant_decode : public 
GemmConfigBase +struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; @@ -184,7 +184,14 @@ struct GemmConfigPreshuffleB_Bquant_decode : public GemmConfigBase }; template -struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase +struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode + : public GemmConfigPreshuffleB_BQuant_Decode +{ + static constexpr bool PreshuffleQuant = true; +}; + +template +struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -206,6 +213,13 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +template +struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill + : public GemmConfigPreshuffleB_BQuant_Prefill +{ + static constexpr bool PreshuffleQuant = true; +}; + template struct GemmConfigBQuantPrefill : public GemmConfigBase { @@ -222,6 +236,12 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill +{ + static constexpr bool PreshuffleQuant = true; +}; + template struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 5089a6ea9a..a573e5c765 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -557,7 +557,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, aq_dev_buf_ptr = std::make_unique(aq_tensor_ptr->get_element_space_size_in_bytes()); } - std::unique_ptr bq_dev_buf_ptr = nullptr; if constexpr(QuantMode 
== ck_tile::QuantType::BQuantGrouped || QuantMode == ck_tile::QuantType::RowColQuant || @@ -626,8 +625,24 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) { printf("Preshuffle BQ with TiledMMAPermuteN \n"); + ck_tile::HostTensor bq_permuted_host = + ck_tile::bq_permuteN(*bq_tensor_ptr); + + if constexpr(GemmConfig::PreshuffleQuant) + { + ck_tile::HostTensor bq_shuffle_host = + ck_tile::shuffle_bq(&bq_permuted_host, GemmConfig::K_Tile / QuantGroupSize::kK); + bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); + } + else + { + bq_dev_buf_ptr->ToDevice(bq_permuted_host.data()); + } + } + else if constexpr(GemmConfig::PreshuffleQuant) + { ck_tile::HostTensor bq_shuffle_host = - ck_tile::shuffle_bq_permuteN(*bq_tensor_ptr); + ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK); bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); } else diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index e3b5c96d91..8be32fa910 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -20,6 +20,49 @@ auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) return ck_tile::reference_permute(t_view, {1, 0, 2}); } +template +auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) +{ + const auto& lengths = t->get_lengths(); + const size_t rank = lengths.size(); + + // Validate block_bq_k divisibility based on rank + int bqk_dim = (rank == 5) ? lengths[4] : (rank == 2) ? 
lengths[0] : -1; + + if(bqk_dim < 0) + { + throw std::runtime_error("shuffle_bq expects either rank-2 or rank-5 tensor, got rank " + + std::to_string(rank)); + } + + if(bqk_dim % block_bq_k != 0) + { + throw std::runtime_error("shuffle_bq needs bqk dimension to be a multiple of block_bq_k."); + } + + // For TilePermuteN + if(rank == 5) + { + // Handle 5D tensor: [n, nrepeat, nwarp, n_warp_tile, bqk] + ck_tile::HostTensor t_view({static_cast(lengths[0]), + static_cast(lengths[1]), + static_cast(lengths[2]), + static_cast(lengths[3]), + bqk_dim / block_bq_k, + block_bq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {4, 0, 1, 2, 3, 5}); + } + else // rank == 2 + { + // Handle 2D tensor: [bqk, n] + int n_ = lengths[1]; + ck_tile::HostTensor t_view({n_, bqk_dim / block_bq_k, block_bq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {1, 0, 2}); + } +} + template auto shuffle_b(const ck_tile::HostTensor& t) { @@ -64,7 +107,7 @@ auto shuffle_b(const ck_tile::HostTensor& t) } template -auto shuffle_bq_permuteN(const ck_tile::HostTensor& t) +auto bq_permuteN(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 6422c07e1d..392dc46f72 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -54,6 +54,8 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); static constexpr index_t NIterPerWarp = BlockTile::at(idxN) / (WarpTile::at(idxN) * 
BlockWarps::at(idxN)); @@ -172,16 +174,47 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg c_warp_y_index_zeros)) / CBlockTensor::PackedSize>{}; - constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale; + if constexpr(PreshuffleQuant) + { + constexpr index_t reg_offset = nIter; + auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float scale_reg_f = cvt_scale_to_fp32(scale_reg); + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } - static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { - auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; - const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; - c_ref = c_ref + acc_val * scale_reg_f; - }); + // cross lane ops to get the value of scale_reg. 
+ int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * scale_reg_f; + }); + } + else + { + constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float scale_reg_f = cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * scale_reg_f; + }); + } }); }); }); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 608de80a7a..eb59f89a69 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -274,7 +274,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); - return Base::cvt_scale_to_fp32(gathered_scale_reg); } @@ -368,7 +367,6 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase static_assert(false, "WarpGemm::kM is not 16 nor 32."); } auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter]; - return exchange_quant_value_across_lanes(scale_reg, pull_from_lane); } else @@ -511,6 +509,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( 
[&](auto c_row) { float scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); }); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 41ed272d0d..b7c0eb2198 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -100,6 +100,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(KPerBlock, QuantGroupSize::kK); static constexpr index_t QScalesPerWarpGemmRow = @@ -173,6 +175,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; + static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant; + static_assert(std::is_same_v); static constexpr auto a_warp_y_lengths = @@ -321,31 +325,65 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase } }); - // Multiply bquant with accumulated C - constexpr index_t reg_offset = [&]() { - if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN)) - return (nIter * NWarp * WarpGemm::kN) / - GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock + - kQScale; - else - { - return nIter * Traits::KQPerBlock + kQScale; - } - }(); - constexpr auto tbuf_offset = number{}, c_warp_y_index_zeros)) / CBlockTensor::PackedSize>{}; - auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; - float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); - 
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); - }); + if constexpr(PreshuffleQuant) + { + constexpr index_t reg_offset = nIter; + auto pull_from_lane = + (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + // cross lane ops to get the value of scale_reg. + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); + } + else + { + // Multiply bquant with accumulated C + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::QuantGroupSize::kN >= + (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / + GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); + } }); }); }); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 6c90d5c1a6..245ce4fa89 100644 --- 
a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -271,6 +271,94 @@ struct QuantGemmKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + private: + CK_TILE_DEVICE static constexpr index_t get_padding_size(index_t length, index_t alignment) + { + return ck_tile::integer_least_multiple(length, alignment) - length; + }; + // =================================================================== + // Helper: Create Pre-shuffled Quantization Tensor Descriptor + // =================================================================== + template + CK_TILE_DEVICE static auto + MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B) + { + // Step 1: Calculate base BQ tensor dimensions + // ---------------------------------------------------------- + // bq_x: Number of quantization groups in N dimension + // = N * KPerBlockBQ, where KPerBlockBQ is the number of + // K-dimension groups per block + // bq_y: Number of quantization groups in K dimension + // = Total K groups (QK_B) / groups per block + const auto bq_x = N * KPerBlockBQ; + const auto bq_y = QK_B / KPerBlockBQ; + + const auto bq_desc = make_naive_tensor_descriptor( + make_tuple(bq_y, bq_x), make_tuple(bq_x, 1), number{}, number<1>{}); + + // Step 2: First padding transformation (block-level alignment) + // ---------------------------------------------------------- + // Pad the X dimension to be a multiple of block_tile_size to ensure + // each thread block can process complete tiles without edge cases + const auto block_tile_size = NPerBlock * KPerBlockBQ; + const auto bq_pad0_desc = transform_tensor_descriptor( + bq_desc, + make_tuple(make_pass_through_transform(bq_y), + make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // Step 3: Unmerge transformation (wave-level 
decomposition) + // ---------------------------------------------------------- + // Split the X dimension into [wave_tile_count_x, wave_tile_size] + // This separates the work into tiles that can be processed by + // individual warps/waves + const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; + const auto wave_tile_size = WarpTileN * KPerBlockBQ; + const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); + + const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( + bq_pad0_desc, + make_tuple(make_pass_through_transform(bq_y), + make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + // Step 4: Second padding transformation (warp-level alignment) + // ---------------------------------------------------------- + // Pad wave_tile_size to be a multiple of warp_size (typically 32 or 64) + // This ensures coalesced memory accesses within each warp + const auto bq_pad1_desc = transform_tensor_descriptor( + bq_unmerge_pad0_desc, + make_tuple(make_pass_through_transform(bq_y), + make_pass_through_transform(wave_tile_count_x), + make_right_pad_transform(wave_tile_size, + get_padding_size(wave_tile_size, get_warp_size()))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + // Step 5: Final merge transformation (prepare for indexing) + // ---------------------------------------------------------- + // Merge [bq_y, wave_tile_count_x] into a single outer dimension + // This creates a 2D layout: [merged_outer_dim, pad_wave_size] + // where merged_outer_dim = bq_y * wave_tile_count_x + // This layout facilitates efficient block-to-data mapping + const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); + const auto bq_merge_pad1_desc = transform_tensor_descriptor( + bq_pad1_desc, + 
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), + make_pass_through_transform(pad_wave_size)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tensor_view(bq_ptr, bq_merge_pad1_desc); + } + + public: struct SplitKBatchOffset { __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs, @@ -509,17 +597,12 @@ struct QuantGemmKernel } }(); - const auto get_padding_size = [](index_t length, index_t alignment) { - return ck_tile::integer_least_multiple(length, alignment) - length; - }; - const auto& aq_tensor_view = [&]() { if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) { static_assert(std::is_same_v); const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; - const auto aq_desc = make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), make_tuple(aq_x, 1), @@ -540,6 +623,7 @@ struct QuantGemmKernel GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); + const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( aq_pad0_desc, make_tuple( @@ -686,14 +770,27 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::BQuantGrouped) { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(1, kargs.stride_BQ), - number{}, - number<1>{}); + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v); + + return MakePreshuffledQuantTensorView< + GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlock, + TilePartitioner::BlockGemmShape::WarpTile::at(I1), + GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); + } + else + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_naive_tensor_view( + 
bq_ptr, + make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(1, kargs.stride_BQ), + number{}, + number<1>{}); + } } else { @@ -910,13 +1007,33 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::BQuantGrouped) { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); + constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto tile_window_width = + ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); + constexpr auto tile_window_height = block_n / warp_n; + auto block_n_idx = i_n / block_n; + + return make_tile_window( + bq_pad_view, + make_tuple(number{}, number{}), + {block_n_idx * tile_window_height, 0}); + } + else + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } } else { @@ -979,14 +1096,24 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped) { const auto& aq_block_window = gemm_tile_windows.at(I1); + index_t m = 0; + if constexpr(PreshuffleQuant) + { + m = kargs.M; + } return GemmPipeline{}.template operator()( - a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0); + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m); } else if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(I3); + index_t n = 0; + if constexpr(PreshuffleQuant) + { + n = kargs.N; + } return GemmPipeline{}.template 
operator()( - a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0); + a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n); } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) @@ -1074,12 +1201,18 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(I3); + index_t n = 0; + if constexpr(PreshuffleQuant) + { + n = kargs.N; + } return GemmPipeline{}.template operator()(a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, - smem_ptr_1); + smem_ptr_1, + n); } else { @@ -1109,7 +1242,6 @@ struct QuantGemmKernel const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - const SplitKBatchOffset splitk_batch_offset(kargs); // options const ADataType* a_ptr = static_cast(kargs.a_ptr); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 2568f1c07c..9dea74c425 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -463,11 +463,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem{} .template operator()( a_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 77abc92656..41052cb485 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -465,9 +465,9 @@ struct 
AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, const AQDramBlockWindowTmp& aq_dram_block_window_tmp, - index_t m, index_t num_loop, - void* p_smem) const + void* p_smem, + index_t m = 0) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 7f35865a2f..4f792e9de8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -35,30 +35,48 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using BQLayout = remove_cvref_t; using BlockGemmShape = typename Problem::BlockGemmShape; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmDispatcher; + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; + constexpr index_t VecLoadSize = GetVectorSizeBQ(); + constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; 
static_assert(std::is_same_v); - using TileEncodingPattern = - tile_distribution_encoding_pattern_bq; + if constexpr(PreshuffleQuant) + { + using TileEncodingPattern = tile_distribution_encoding_pattern_bq< + BlockGemmShape, + WarpGemm, + BlockSize, + NPerBlock / WarpGemm::kN, + ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), + VecLoadSize, + PreshuffleQuant>; + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_bq; - return TileEncodingPattern::make_2d_static_tile_distribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); + } } template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 9820aeb658..584a15571c 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -137,6 +137,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; @@ -238,6 +239,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t n, index_t num_loop, void* p_smem) const { @@ -257,9 +259,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV constexpr bool is_b_row_major = std::is_same_v; static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); - 
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], - "Bq block window has incorrect lengths for defined BqLayout!"); static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -315,8 +314,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); constexpr BDramTileWindowStep b_dram_tile_window_step = is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr BQDramTileWindowStep bq_dram_tile_window_step = - is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); + const BQDramTileWindowStep bq_dram_tile_window_step = + (PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{}), + 0) + : is_bq_col_major ? make_array(KPerBlockBQ, 0) + : make_array(0, KPerBlockBQ); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); @@ -457,6 +460,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV return c_block_tile; } }; + // Overload for PreshuffleQuant = true template @@ -464,7 +468,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV const BDramBlockWindowTmp& b_dram_block_window_tmp, const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t num_loop, - void* p_smem) const + void* p_smem, + index_t n = 0) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, @@ -472,6 +477,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV b_dram_block_window_tmp, [](const BDataType& b) { return b; }, bq_dram_block_window_tmp, + n, num_loop, p_smem); } @@ -502,7 +508,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV index_t num_loop, 
bool has_hot_loop, TailNumber tail_number, - void* p_smem) const + void* p_smem, + index_t n = 0) const { const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { constexpr bool hot_loop = has_hot_loop_.value; @@ -513,6 +520,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV b_dram_block_window_tmp, [](const BDataType& b) { return b; }, bq_dram_block_window_tmp, + n, // dummy value, won't be used num_loop, p_smem); }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 43dbf95941..6cd8dc3e0f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -171,7 +171,8 @@ template + index_t XPerQ, + bool PreshuffleQuant = false> struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { static constexpr index_t warp_size = get_warp_size(); @@ -213,52 +214,71 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { - if constexpr(XPerQ < WarpGemm::kN) + if constexpr(PreshuffleQuant) { - // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t Y = YPerTile; // Full Y dimension of tile - constexpr index_t YR = 1; // No Y replication needed - constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t X1 = NWarps; // Number of warps in N-dim - constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp - constexpr index_t XR = XPerQ; // Elements per quantization group - - static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X."); + constexpr index_t X1 = warp_size; + constexpr index_t X0 = XPerTile / warp_size; + 
constexpr index_t Y1 = NWarps; + constexpr index_t Y0 = YPerTile / Y1; return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 2, 0>>, - tuple, sequence<1, 2, 2>>, - sequence<2, 1>, - sequence<0, 0>>{}); - } - else if constexpr(XPerQ <= WarpGemm::kN * NWarps) - { - // Case 2: Medium-grained - one quantization scale per warp - constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto X1 = NWarps / XR; // Warps per unique scale - constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<2, 1>, - sequence<0, 0>>{}); - } - else // XPerQ > WarpGemm::kN * NWarps - { - // Case 3: Coarse-grained - quantization group spans all warps - // All warps in N-dimension share the same quantization scale - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, + tile_distribution_encoding, + tuple, sequence>, tuple, sequence<2>>, - sequence<2, 1>, + tuple, sequence<1>>, + sequence<1, 2>, sequence<0, 0>>{}); } + else + { + if constexpr(XPerQ < WarpGemm::kN) + { + // Case 1: Fine-grained - multiple quantization scales within a single warp + constexpr index_t Y = YPerTile; // Full Y dimension of tile + constexpr index_t YR = 1; // No Y replication needed + constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t X1 = NWarps; // Number of warps in N-dim + constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp + constexpr index_t XR = XPerQ; // Elements per quantization group + + static_assert(X0 * X1 * X2 == XPerTile, + "X0, X1, X2 must cover the blocktile along X."); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<2, 1>, + 
sequence<0, 0>>{}); + } + else if constexpr(XPerQ <= WarpGemm::kN * NWarps) + { + // Case 2: Medium-grained - one quantization scale per warp + constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto X1 = NWarps / XR; // Warps per unique scale + constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } + else // XPerQ > WarpGemm::kN * NWarps + { + // Case 3: Coarse-grained - quantization group spans all warps + // All warps in N-dimension share the same quantization scale + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } + } } }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 825c86b0a1..b92c9ee1fd 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -68,6 +68,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV using Base::m_preload; + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr index_t KPerBlockBQ = integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = @@ -106,6 +107,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t n, index_t num_loop, void* p_smem_ping, void* p_smem_pong) const @@ -236,7 +238,7 @@ struct WPQuantBPipelineAgBgCrV2 : 
public WeightPreshufflePipelineAGmemBGmemCRegV // BQ DRAM window for load auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + bq_dram_block_window_tmp.get_window_lengths(), bq_dram_block_window_tmp.get_window_origin(), PipelinePolicy::template MakeBQDramTileDistribution()); @@ -269,8 +271,17 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV BQBlockTile bq_block_tile, bq_block_tile_2; bq_block_tile = load_tile(bq_copy_dram_window); // move BQ to tile 1 - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); - + if constexpr(PreshuffleQuant) + { + move_tile_window(bq_copy_dram_window, + {ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{}), + 0}); + } + else + { + move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + } // Prefill A0 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); @@ -318,7 +329,17 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); bq_block_tile_2 = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + if constexpr(PreshuffleQuant) + { + move_tile_window(bq_copy_dram_window, + {ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{}), + 0}); + } + else + { + move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + } // Prefill A(2i+1) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); @@ -360,7 +381,17 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + if constexpr(PreshuffleQuant) + { + 
move_tile_window(bq_copy_dram_window, + {ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{}), + 0}); + } + else + { + move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + } // Prefill A(2i+2) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); @@ -448,6 +479,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV return c_block_tile; } + // Convenience overload: forwards to the element-function variant with an identity A op template @@ -456,14 +488,15 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t num_loop, void* p_smem_ping, - void* p_smem_pong) const + void* p_smem_pong, + index_t n = 0) const // Default value for non-preshuffle case { - return operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, b_flat_dram_block_window_tmp, bq_dram_block_window_tmp, + n, num_loop, p_smem_ping, p_smem_pong); @@ -478,7 +511,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV index_t num_loop, TailNumber tail_number, void* p_smem_ping, - void* p_smem_pong) const + void* p_smem_pong, + index_t n = 0) const { const auto RunPipeline = [&](auto bool_val, auto tail_num_) { (void)bool_val; // Suppress unused parameter warning @@ -488,6 +522,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV [](const ADataType& a) { return a; }, b_flat_dram_block_window_tmp, bq_dram_block_window_tmp, + n, // dummy value, won't be used num_loop, p_smem_ping, p_smem_pong); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0be276de8d..5e610cb76b 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -113,6 +113,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public 
GemmConfigPreshuffleBP static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode +{ + static constexpr bool PreshuffleQuant = true; +}; + template class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase> { @@ -436,7 +441,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase bq_shuffle_host = - ck_tile::shuffle_bq_permuteN(bq_bqk_bqn); + ck_tile::bq_permuteN(bq_bqk_bqn); + bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); + } + else if constexpr(GemmConfig::PreshuffleQuant) + { + ck_tile::HostTensor bq_shuffle_host = + ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } else diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp index 3ace9188cc..34bdf4ea38 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -111,7 +111,12 @@ using BPreshuffleBQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on