[CK TILE] Fix bugs in AQuant preshuffle (#2700)

* [CK TILE] Fix bugs in AQuant preshuffle

- Make Aquant works with block Mx64x256. `M` could be 16, 32, 64
- Make Aquant works with warp 16x16x32 and 32x32x16.

* [CK TILE] Rename Preshuffle to PreshuffleQuant

The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.

* [CK TILE Block Scale] Use GemmConfig to save tile properties

- Remove specialization of GemmQuantTypeConfig
- Pass GemmConfig around which contains tile properties. Stop using hard
  coded tile properties in `gemm_calc_aquant()`

* [CK TILE Block Scale] Rename GemmConfig used in block scale

    - Remove unused GemmConfig
    - Rename GemmConfig used in block scale

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-08-27 01:05:54 -06:00
committed by GitHub
parent 95e4a4efcb
commit 245467f359
12 changed files with 266 additions and 1069 deletions

View File

@@ -157,7 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
};
public:
@@ -357,7 +357,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
}
});
if constexpr(Traits::Preshuffle)
if constexpr(Traits::PreshuffleQuant)
{
// A view is created on top of the preshuffled AQ, where each row of the
// view is composed of a row from a warp tile within an AQ block tile.
@@ -392,12 +392,27 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
auto pull_from_lane =
((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
decltype(threadIdx.x) pull_from_lane = 0;
if constexpr(WarpGemm::kM == 16)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
}
else if constexpr(WarpGemm::kM == 32)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
((c_row >> 2) << 3) + (c_row & 0b11)) *
Traits::QScalesPerBlockRow +
kQScale;
}
else
{
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
}
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
// cross lane ops