[CK_TILE] B matrix 2D block scale gemm (#3074)

* Refactor quant group size to be configurable for M/N/K, not just K

* add some asserts for configurations not implemented

* start setting of group size for N dimension

* enable 2d for reference quant gemm

* WIP: trying to figure out tile dstr and/or indexing for scale matrix

* WIP

* Fix handling of n dim blocks in tile windows etc

* remove commented code and enable all tests again

* fix formatting

* Add more specialized tile distributions

* Enable NWarps replication for bquant tile dstr

* fix formatting

* fix format

* Fix some issues from the merge

* fix formatting

* one more fix to tile dstr, and revert debug initialization

* Remove commented code

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* simplify conditions that are needed for tile distributions

* only enable the working group sizes in tests

* fix formatting

* Update tile distribution for 2D bquant

* add some documentation and 2d block scale example

* fix formatting

* Add in Changlog and restructure the quant 2d example

* fix CMake

* support the change for blockscale 2d

* fix the test file

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Sami Remes
2025-11-03 00:49:20 +00:00
committed by GitHub
parent 73f637894d
commit 16e85cf179
24 changed files with 476 additions and 363 deletions

View File

@@ -10,7 +10,7 @@ namespace ck_tile {
// A is block window on shared memory
// BQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
// B is block window on block distributed tensor.
// C is block distributed tensor
template <typename Problem_, typename BlockPolicy_>
@@ -24,6 +24,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -47,8 +51,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
@@ -58,13 +61,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
static constexpr index_t KPerBlockBQ = KPerBlock / kQuantGroupSize;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
(WG::kK + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read

View File

@@ -46,7 +46,7 @@ struct BlockGemmAQuantBase
// A is block window on shared memory
// AQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
// Consecutive QuantGroupSize elements of A are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_,
@@ -66,16 +66,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
// Threadblock GEMM tile size
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize;
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -101,20 +101,20 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / kQuantGroupSize > 0,
static_assert(KPerBlock / QuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,

View File

@@ -46,7 +46,7 @@ struct BlockGemmBQuantBase
// A is block window on shared memory
// BQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_,
@@ -66,16 +66,18 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
// Threadblock GEMM tile size
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t BQPerBlock = KPerBlock / kQuantGroupSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -101,20 +103,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / kQuantGroupSize > 0,
static_assert(KPerBlock / QuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
@@ -340,23 +342,17 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
}
});
// Need to multiply bquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
// These elements are in different rows, need to get the scale value
// for the corresponding row.
// Based on bquant's tile distribution, it can be inferred which
// lane holds the relevant scale. For example, the scales corresponding
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
constexpr index_t reg_offset = nIter * Traits::BQPerBlock + kQScale;
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(

View File

@@ -685,9 +685,10 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.QK_B, kargs.N),
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(1, kargs.stride_BQ),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
@@ -831,10 +832,10 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block =
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
@@ -847,11 +848,12 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_pad_view,
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
@@ -907,11 +909,12 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{

View File

@@ -18,6 +18,7 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
@@ -25,10 +26,9 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize;
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
static_assert(KPerBlock % QuantGroupSize == 0,
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize");
// Create DRAM tile window for AQ

View File

@@ -86,6 +86,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -106,12 +110,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -147,7 +150,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize,
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
// clang-format on
}
@@ -204,7 +207,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();

View File

@@ -22,7 +22,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
@@ -37,7 +37,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
@@ -99,8 +99,8 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,

View File

@@ -91,6 +91,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -111,12 +115,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -152,7 +155,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
// clang-format on
}
@@ -208,7 +211,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();

View File

@@ -18,6 +18,7 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
@@ -25,11 +26,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize;
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static_assert(KPerBlock % QuantGroupSize == 0,
"KPerBlock must be a multiple of QuantGroupSize");
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
static_assert(NPerBlock % QuantGroupSize::kN == 0,
"NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");
// Create DRAM tile window for BQ
template <typename BQDramBlockWindowTmp>
@@ -38,7 +44,7 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using YPerTile = number<NPerBlock>;
using YPerTile = number<NPerBlockBQ>;
using XPerTile = number<KPerBlockBQ>;
auto bq_copy_dram_window =

View File

@@ -21,11 +21,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
template <typename Problem>
@@ -36,9 +37,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
@@ -49,12 +50,13 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
Problem::TransposeC>;
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ,
NPerBlock,
VecLoadSize>;
using TileEncodingPattern =
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ,
NPerBlockBQ,
Problem::QuantGroupSize::kN>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
@@ -65,8 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of QuantGroupSize!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,

View File

@@ -91,7 +91,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
@@ -111,12 +113,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -151,7 +154,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
// clang-format on
}
@@ -207,7 +210,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -255,7 +258,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Bq block window has incorrect lengths for defined BqLayout!");
static_assert(is_a_col_major

View File

@@ -171,11 +171,9 @@ template <typename BlockGemmShape,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize>
index_t XPerQ>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
@@ -186,34 +184,94 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
static_assert(num_warps == MWarps * NWarps * KWarps);
// KWarps > 1 isn't supported
static_assert(KWarps == 1);
// # of elements per thread
static constexpr index_t Y = YPerTile;
static constexpr index_t YR = 1;
// Number of iters per warp
// MIters are indexed using (Y0, Y1)
static constexpr index_t X0 = NIterPerWarp;
// # of warps in Y dim
static constexpr index_t X1 = NWarps;
static constexpr index_t X2 = WarpGemm::kN;
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along Y.");
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast
/// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN
/// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, YR>,
tuple<sequence<Y>, sequence<X0, X1, X2>>,
tuple<sequence<0, 2>, sequence<0, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2, 1>,
sequence<0, 0>>{});
if constexpr(XPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, YR, XR>,
tuple<sequence<Y>, sequence<X0, X1, X2>>,
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
else // XPerQ > WarpGemm::kN * NWarps
{
// Case 3: Coarse-grained - quantization group spans all warps
// All warps in N-dimension share the same quantization scale
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<XPerTile>>,
tuple<sequence<0, 0>, sequence<0>>,
tuple<sequence<0, 1>, sequence<2>>,
sequence<2, 1>,
sequence<0, 0>>{});
}
}
};
template <typename GroupSizes>
struct QuantGroupShape
{
static constexpr index_t kM = GroupSizes::at(number<0>{});
static constexpr index_t kN = GroupSizes::at(number<1>{});
static constexpr index_t kK = GroupSizes::at(number<2>{});
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
return concat('_', "quant_group_shape", concat('x', kM, kN, kK));
}
};

View File

@@ -18,7 +18,7 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename QuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
@@ -48,6 +48,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
using BQDataType = remove_cvref_t<BQDataType_>;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = QuantGroupSize_;
using typename Base::ALayout;
using typename Base::BLayout;
@@ -67,12 +68,13 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
static_assert(BlockGemmShape::kM % QuantGroupSize::kM == 0);
static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
static_assert(BlockGemmShape::kK % QuantGroupSize::kK == 0);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -81,8 +83,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler,
"QuantGroupSize",
kQuantGroupSize);
QuantGroupSize::GetName());
// clang-format on
}
@@ -111,7 +112,7 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename QuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
@@ -137,7 +138,7 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
uint32_t QuantGroupSize_,
typename QuantGroupSize_,
typename ComputeDataType_ = ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
@@ -175,7 +176,7 @@ using GemmRowColTensorQuantPipelineProblem =
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
QuantGroupShape<sequence<1, 1, 1>>, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,

View File

@@ -15,10 +15,11 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
{
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
template <typename Problem>

View File

@@ -25,6 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
@@ -68,10 +69,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using Base::m_preload;
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
(kKPerBlock + QuantGroupSize - 1) / QuantGroupSize;
integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
static constexpr index_t GetVectorSizeBQ()
{
@@ -89,7 +90,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize);
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
// clang-format on
}