[CK TILE] Fix bugs in AQuant preshuffle (#2700)

* [CK TILE] Fix bugs in AQuant preshuffle

- Make Aquant works with block Mx64x256. `M` could be 16, 32, 64
- Make Aquant works with warp 16x16x32 and 32x32x16.

* [CK TILE] Rename Preshuffle to PreshuffleQuant

The new name, PreshuffleQuant, explicitly states the function's purpose:
to preshuffle the quantization matrix.

* [CK TILE Block Scale] Use GemmConfig to save tile properties

- Remove specialization of GemmQuantTypeConfig
- Pass GemmConfig around which contains tile properties. Stop using hard
  coded tile properties in `gemm_calc_aquant()`

* [CK TILE Block Scale] Rename GemmConfig used in block scale

    - Remove unused GemmConfig
    - Rename GemmConfig used in block scale

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-08-27 01:05:54 -06:00
committed by GitHub
parent 95e4a4efcb
commit 245467f359
12 changed files with 266 additions and 1069 deletions

View File

@@ -157,7 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
};
public:
@@ -357,7 +357,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
}
});
if constexpr(Traits::Preshuffle)
if constexpr(Traits::PreshuffleQuant)
{
// A view is created on top of the preshuffled AQ, where each row of the
// view is composed of a row from a warp tile within an AQ block tile.
@@ -392,12 +392,27 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0]
// from thread 1, ..., and AQ_tile[3, 0] from thread 3.
auto pull_from_lane =
((threadIdx.x & (warp_size - 1)) / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
decltype(threadIdx.x) pull_from_lane = 0;
if constexpr(WarpGemm::kM == 16)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
c_row) *
Traits::QScalesPerBlockRow +
kQScale;
}
else if constexpr(WarpGemm::kM == 32)
{
pull_from_lane = (__lane_id() / Traits::WarpGemm::kN *
kTileRowsOfCPerThread +
((c_row >> 2) << 3) + (c_row & 0b11)) *
Traits::QScalesPerBlockRow +
kQScale;
}
else
{
static_assert(false, "WarpGemm::kM is not 16 nor 32.");
}
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
// cross lane ops

View File

@@ -99,15 +99,15 @@ struct AQuantGemmKernelArgs
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct AQuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool Preshuffle = GemmPipeline::Preshuffle;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using AQLayout = remove_cvref_t<typename GemmPipeline::AQLayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant = GemmPipeline::PreshuffleQuant;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using AQDataType = remove_cvref_t<typename GemmPipeline::AQDataType>;
@@ -422,9 +422,9 @@ struct AQuantGemmKernel
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(wave_tile_count_x, aq_y)),
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
@@ -432,7 +432,7 @@ struct AQuantGemmKernel
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
return make_preshuffled_aq_tensor_view();
}
@@ -599,10 +599,8 @@ struct AQuantGemmKernel
}
template <typename PadView>
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
const AQuantGemmKernelArgs& kargs,
const index_t i_m,
const index_t i_n)
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
@@ -628,24 +626,27 @@ struct AQuantGemmKernel
const auto& aq_block_window = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block =
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
if constexpr(PreshuffleQuant)
{
constexpr auto tile_window_width = get_warp_size();
constexpr auto tile_window_height =
TilePartitioner::MPerBlock / TilePartitioner::BlockGemmShape::WarpTile::at(I0);
auto block_m_idx = i_m / TilePartitioner::MPerBlock;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
auto block_m_idx = i_m / block_m;
return make_tile_window(
aq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * kargs.K / TilePartitioner::BlockGemmShape::BlockTile::at(I2),
0});
{block_m_idx * tile_window_height, 0});
}
else
{
return make_tile_window(
aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
{i_m, 0});
}
}();
@@ -706,8 +707,7 @@ struct AQuantGemmKernel
a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
MakeGemmTileWindows(gemm_pad_views, kargs, block_idx_m, block_idx_n);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
@@ -718,7 +718,7 @@ struct AQuantGemmKernel
const auto& b_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);
a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);

View File

@@ -37,23 +37,23 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern =
TileDistributionEncodingPatternAQ<BlockGemmShape,
@@ -64,7 +64,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
@@ -77,7 +77,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
Preshuffle>;
PreshuffleQuant>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}

View File

@@ -7,6 +7,7 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
@@ -133,7 +134,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -235,6 +236,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
{
@@ -311,9 +313,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
: make_array(0, KPerBlockAQ);
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: make_array(0, KPerBlockAQ);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
@@ -458,6 +462,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
index_t m,
index_t num_loop,
void* p_smem) const
{
@@ -467,6 +472,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,
m,
num_loop,
p_smem);
}

View File

@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
index_t XPerTile,
index_t KPerBlockAQ,
index_t VecSize,
bool Preshuffle>
bool PreshuffleQuant>
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -72,20 +72,20 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
if constexpr(Preshuffle)
if constexpr(PreshuffleQuant)
{
// # of elements per thread
constexpr index_t X2 = KPerBlockAQ;
constexpr index_t X1 = warp_size / X2;
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
constexpr index_t X1 = warp_size;
constexpr index_t X0 = XPerTile / warp_size;
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = YPerTile / Y1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
}

View File

@@ -10,7 +10,7 @@ namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool Preshuffle_,
bool PreshuffleQuant_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
@@ -30,7 +30,7 @@ struct TileGemmAQuantTraits
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = Preshuffle_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
} // namespace ck_tile