[CK_TILE] ABQuant New Preshuffle (#3638)

* Refactor

* Gemm quant improvement

* Change preshuffle

* Fix

* Fix grouped GEMM unit test

* Fix

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Yi DING
2026-01-28 15:46:49 +08:00
committed by GitHub
parent 91e32f305f
commit 8e3d84aba3
32 changed files with 182 additions and 213 deletions

View File

@@ -693,13 +693,13 @@ struct QuantGemmKernel
{
if constexpr(PreshuffleB)
{
index_t kFlatK =
GemmPipeline::flatKPerWarp *
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
constexpr auto warp_k = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
index_t kFlatKSplit = GemmPipeline::flatKPerWarp * (k_size / warp_k);
index_t kFlatK = GemmPipeline::flatKPerWarp * (kargs.K / warp_k);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatN, kFlatKSplit),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});

View File

@@ -52,11 +52,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using BDataType = typename Problem::BDataType;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KBPerLoad =
min(GetKBPerLoad<Problem>(), 16 / static_cast<index_t>(sizeof(BDataType)));
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
@@ -64,8 +66,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t KRepeat = GetKBPerLoad<Problem>() / KBPerLoad;
static_assert(TileShape::flatKPerWarp == KRepeat * KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
@@ -98,13 +100,23 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
Problem::TransposeC,
false,
false,
NumAccess>;
// TODO : Use a custom block policy for AsBrCr
using BlockGemmPolicy =