From cdba8b787c920141745aa186acd58cb1c8210aac Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Wed, 11 Feb 2026 00:04:44 -0700 Subject: [PATCH] [CK TILE] fix numerical errors of preshuffle_b (#4354) This pull request introduces several improvements and fixes related to quantized grouped GEMM (General Matrix Multiply) pipelines and their supporting utilities. # The numerical issue ## Steps to reproduce ```bash Run ./bin/tile_example_gemm_weight_preshuffle -prec=fp8 ./bin/tile_example_gemm_weight_preshuffle -prec=int4 ``` # Solution The main changes address type correctness, improve data layout and shuffling logic, and expand test coverage to better validate different GEMM configurations. **Key changes include:** ### Data layout and shuffling logic * Refactored the logic in `shuffle_b_permuteN` to use `constexpr` variables for `KLane` and `ItemsPerAccess`, simplifying tile view construction and correcting the permutation order for improved efficiency and correctness (`tensor_shuffle_utils.hpp`). * Fixed the calculation of `KLaneBytes` in weight preshuffle pipeline policies to account for internal data type conversion (e.g., from `pk_int4_t` to `fp8`), ensuring accurate memory access and alignment in quantized GEMM policies (`wp_pipeline_agmem_bgmem_creg_base_policy.hpp`, `gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp`). [[1]](diffhunk://#diff-93f16cd76e6e24404777e682a5ac8e039913ddd6a438c7efd61fdda42276e4efL274-R275) [[2]](diffhunk://#diff-9c3d0fc3c014feed435bfd93ba1f8f9fb3e054dcc322deada3addf70bee5a58cL100-R105) ### Test infrastructure enhancements * Unit tests did not catch this issue since there were no tests for fp8. Added new configuration structs (`config_mn_16x16`, `config_mn_32x32`) to support additional GEMM tile shapes and updated tests to run with these configurations for broader coverage (`test_gemm_pipeline_util.hpp`). [[1]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8R86-R103) [[2]](diffhunk://#diff-5a5962b2c4aa7f6a87d1d6201ad383135e30df13b42654e997d870d57420d5b8L255-R269) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../17_grouped_gemm/abquant_grouped_gemm.cpp | 13 ++++----- include/ck_tile/host/tensor_shuffle_utils.hpp | 20 +++++--------- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 27 +++++++++---------- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 +++ .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 4 +-- ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 11 +++++--- .../test_gemm_pipeline_util.hpp | 18 +++++++++++-- 7 files changed, 55 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp index 703751b760..28b3884d0f 100644 --- a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -75,8 +75,8 @@ float grouped_gemm_abquant(const std::vector& gemm_descs, ck_tile::GemmPipelineProblem; using BaseGemmPipeline = - GemmQuantConfig::template BaseGemmPipeline; + typename GemmQuantConfig::template BaseGemmPipeline; const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; @@ -108,8 +108,8 @@ float grouped_gemm_abquant(const std::vector& gemm_descs, tail_number_v>; using GemmPipeline = - GemmQuantConfig::template GemmPipeline; + typename GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem; - using GemmPipeline = GemmQuantConfig::template GemmPipeline; + using GemmPipeline = + typename GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem& t, const GemmConfig& gemmC } else { - int divisor = 1; - if(ck_tile::is_gfx11_supported()) - { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; - } + constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = + std::min(16 / static_cast(sizeof(T)), GemmConfig::K_Warp_Tile / KLane); ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, - k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + k_ / ItemsPerAccess, + ItemsPerAccess}); std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); } } diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 0044b412ec..4903f8e501 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -271,20 +271,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; - using BDataType = typename Problem::BDataType; - constexpr index_t KLaneBytes = - KLane / numeric_traits::PackedSize * sizeof(BDataType); - constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher; + // When BDataType is pk_int4_t, it is internally converted to fp8 for computation. + constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse); + constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); + using WarpGemm = WarpGemmDispatcher; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cfd12313e8..cb36d02aa5 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -65,8 +65,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3())>; // A/B DataType gets converted from PkInt4/PkFp4 during loading - using OverrideADataType = BlockGemm::OverrideADataType; - using OverrideBDataType = BlockGemm::OverrideBDataType; + using OverrideADataType = typename BlockGemm::OverrideADataType; + using OverrideBDataType = typename BlockGemm::OverrideBDataType; static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index f136b86314..36d8560543 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -97,10 +97,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; - using BDataType = typename Problem::BDataType; - constexpr index_t KLaneBytes = - KLane / numeric_traits::PackedSize * sizeof(BDataType); - constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); + + // When BDataType is pk_int4_t, it is internally converted to fp8 for computation. + using BTypeToUse = mixed_prec_compute_type_from_input_t; + constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse); + constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); using WarpGemm = WarpGemmDispatcher +struct config_mn_32x32 : public config +{ static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +}; + +template +struct config_mn_16x16 : public config +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template @@ -252,7 +264,9 @@ class TestCkTileGemmPipeline : public ::testing::Test RunSingle, PadM, PadN, PadK, Preshuffle>( M, N, K, StrideA, StrideB, StrideC, kb); #else - RunSingle, PadM, PadN, PadK, Preshuffle>( + RunSingle, PadM, PadN, PadK, Preshuffle>( + M, N, K, StrideA, StrideB, StrideC, kb); + RunSingle, PadM, PadN, PadK, Preshuffle>( M, N, K, StrideA, StrideB, StrideC, kb); #endif }