From 3790ed04c686d27eb8949699484dbc8cd545cc71 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Thu, 11 Sep 2025 12:38:16 -0600 Subject: [PATCH] [CK TILE GEMM] Fixed the regression issue with transpose C in Quant Gemm (#2819) The numerical error was introduced after merging row/col quant; this commit fixes it. [ROCm/composable_kernel commit: 2ed39f8d918f81f3b0ad450c26423d638a437d99] --- example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp | 9 ++++----- .../block/block_universal_gemm_as_aquant_bs_cr.hpp | 7 ++++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index 35ffcf1d56..79c6cca6cb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -57,10 +57,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - constexpr bool transposed_warp_gemm = false; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; @@ -128,7 +127,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, - transposed_warp_gemm, + transpose_c, 
ck_tile::memory_operation_enum::set>>; using Kernel = ck_tile::QuantGemmKernel; diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index b41f01b951..182d9251b1 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -358,10 +358,11 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase if constexpr(Traits::PreshuffleQuant) { + static_assert(false, + "It is not supported yet to enable both Preshuffle and " + "TransposeC."); if constexpr(Traits::TransposeC) // transposed C { - static_assert(false, - "It is not supported yet to enable both Preshuffle."); // TODO: // A new tile distribution is needed for the Preshuffle and // Transpose combination. For instance, with mnk at 16x16x32, lanes @@ -455,7 +456,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase } else { - if(Traits::TransposeC) // transposed C + if constexpr(Traits::TransposeC) // transposed C { constexpr index_t reg_offset = mIter * Traits::AQPerBlock + kQScale; constexpr auto tbuf_offset = number<