[CK TILE GEMM] Fixed the regression issue with transpose C in Quant Gemm (#2819)

The numerical error was introduced when row/col quant was merged; this commit fixes it.

[ROCm/composable_kernel commit: 2ed39f8d91]
This commit is contained in:
Cong Ma
2025-09-11 12:38:16 -06:00
committed by GitHub
parent a7550e6499
commit 3790ed04c6
2 changed files with 8 additions and 8 deletions

View File

@@ -57,10 +57,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
@@ -128,7 +127,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
transposed_warp_gemm,
transpose_c,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;

View File

@@ -358,10 +358,11 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
if constexpr(Traits::PreshuffleQuant)
{
static_assert(false,
"It is not supported yet to enable both Preshuffle and "
"TransposeC.");
if constexpr(Traits::TransposeC) // transposed C
{
static_assert(false,
"It is not supported yet to enable both Preshuffle.");
// TODO:
// A new tile distribution is needed for the Preshuffle and
// Transpose combination. For instance, with mnk at 16x16x32, lanes
@@ -455,7 +456,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
}
else
{
if(Traits::TransposeC) // transposed C
if constexpr(Traits::TransposeC) // transposed C
{
constexpr index_t reg_offset = mIter * Traits::AQPerBlock + kQScale;
constexpr auto tbuf_offset = number<