[CK TILE] Fix bugs in AQuant preshuffle (#2700)

* [CK TILE] Fix bugs in AQuant preshuffle

- Make Aquant works with block Mx64x256. `M` could be 16, 32, 64
- Make Aquant works with warp 16x16x32 and 32x32x16.

* [CK TILE] Rename Preshuffle to PreshuffleQuant

The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.

* [CK TILE Block Scale] Use GemmConfig to save tile properties

- Remove specialization of GemmQuantTypeConfig
- Pass GemmConfig around which contains tile properties. Stop using hard
  coded tile properties in `gemm_calc_aquant()`

* [CK TILE Block Scale] Rename GemmConfig used in block scale

    - Remove unused GemmConfig
    - Rename GemmConfig used in block scale

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-08-27 01:05:54 -06:00
committed by GitHub
parent 95e4a4efcb
commit 245467f359
12 changed files with 266 additions and 1069 deletions

View File

@@ -37,23 +37,23 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern =
TileDistributionEncodingPatternAQ<BlockGemmShape,
@@ -64,7 +64,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
@@ -77,7 +77,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
@@ -133,7 +134,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -235,6 +236,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
{
@@ -311,9 +313,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
: make_array(0, KPerBlockAQ);
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: make_array(0, KPerBlockAQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
@@ -458,6 +462,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
{
@@ -467,6 +472,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,
m,
num_loop,
p_smem);
}

View File

@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
index_t XPerTile,
index_t KPerBlockAQ,
index_t VecSize,
bool Preshuffle>
bool PreshuffleQuant>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -72,20 +72,20 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
// # of elements per thread
constexpr index_t X2 = KPerBlockAQ;
constexpr index_t X1 = warp_size / X2;
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
constexpr index_t X1 = warp_size;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = YPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
}

View File

@@ -10,7 +10,7 @@ namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool Preshuffle_,
bool PreshuffleQuant_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
@@ -30,7 +30,7 @@ struct TileGemmAQuantTraits
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = Preshuffle_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
} // namespace ck_tile