Preshuffle AQ matrix in block scale gemm (#2624)

* Preshuffle AQ matrix in block scale gemm

* Switch the output to fp16 and increase the number of repetitions.

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-08-12 22:32:51 -06:00
committed by GitHub
parent 0f42a92fc1
commit 452791a3ba
13 changed files with 667 additions and 228 deletions

View File

@@ -38,12 +38,9 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using YPerTile = number<MPerBlock>;
using XPerTile = number<KPerBlockAQ>;
auto aq_copy_dram_window =
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
aq_dram_block_window_tmp.get_window_lengths(),
aq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeAQDramTileDistribution<Problem>());
return aq_copy_dram_window;

View File

@@ -42,6 +42,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
@@ -52,14 +53,34 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
VecLoadSize>;
if constexpr(Preshuffle)
{
using TileEncodingPattern =
TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock / WarpGemm::kM,
ck_tile::integer_least_multiple(
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
return TileEncodingPattern::Make2DStaticTileDistribution();
return TileEncodingPattern::Make2DStaticTileDistribution();
}
else
{
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
}
template <typename Problem>

View File

@@ -7,7 +7,6 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
@@ -134,6 +133,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -254,9 +254,6 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Aq block window has incorrect lengths for defined AqLayout!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
@@ -312,8 +309,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
: make_array(0, KPerBlockAQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);

View File

@@ -50,10 +50,11 @@ template <typename BlockGemmShape,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
index_t KPerBlockAQ,
index_t VecSize,
bool Preshuffle>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
@@ -69,26 +70,46 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t X = XPerTile;
static constexpr index_t Y0 = 1;
static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
static constexpr index_t Y2 = MWarps;
static constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
// Builds the 2D static tile distribution for this AQ encoding pattern.
//
// Two encodings appear below, selected by the `Preshuffle` template flag:
//  - Preshuffle:     Y is split into two levels (Y0 = YPerTile / MWarps, Y1 = MWarps) and
//                    X into three (X0 = XPerTile / warp_size, X1 = warp_size / KPerBlockAQ,
//                    X2 = KPerBlockAQ).
//  - non-Preshuffle: Y is split into four levels (Y0..Y3) and X is a single level of
//                    XPerTile elements.
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
// NOTE(review): as rendered here, this unconditional return comes before the
// `if constexpr` below, which would leave both branches unreachable. It looks like
// the pre-change side of the diff with its marker stripped - confirm against the
// actual header before relying on it.
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
sequence<1, 2>,
sequence<1, 0>>{});
if constexpr(Preshuffle)
{
// # of elements per thread
// X is factored as X0 * X1 * X2 with X1 * X2 == warp_size when KPerBlockAQ divides
// warp_size (integer division is used, so other values truncate).
constexpr index_t X2 = KPerBlockAQ;
constexpr index_t X1 = warp_size / X2;
constexpr index_t X0 = XPerTile / warp_size;
// Y is factored as Y0 * Y1, with MWarps as the inner level.
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = YPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else
{
// # of elements per thread
constexpr index_t X = XPerTile;
constexpr index_t Y0 = 1;
// Fall back to 1 when MIterPerWarp evaluates to 0.
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y2 = MWarps;
constexpr index_t Y3 = WarpGemm::kM;
// Holds trivially here because Y3 is defined as WarpGemm::kM just above.
static_assert(Y3 >= WarpGemm::kM,
"Scales for all rows must be available within the warp.");
// The four Y levels must multiply out to exactly the tile height.
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
}
};

View File

@@ -10,6 +10,7 @@ namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool Preshuffle_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
@@ -29,6 +30,7 @@ struct TileGemmAQuantTraits
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = Preshuffle_;
};
} // namespace ck_tile