[CK_TILE] B matrix 2D block scale gemm (#3074)

* Refactor quant group size to be configurable for M/N/K, not just K

* add some asserts for configurations not implemented

* start setting of group size for N dimension

* enable 2d for reference quant gemm

* WIP: trying to figure out tile dstr and/or indexing for scale matrix

* WIP

* Fix handling of n dim blocks in tile windows etc

* remove commented code and enable all tests again

* fix formatting

* Add more specialized tile distributions

* Enable NWarps replication for bquant tile dstr

* fix formatting

* fix format

* Fix some issues from the merge

* fix formatting

* one more fix to tile dstr, and revert debug initialization

* Remove commented code

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* simplify conditions that are needed for tile distributions

* only enable the working group sizes in tests

* fix formatting

* Update tile distribution for 2D bquant

* add some documentation and 2d block scale example

* fix formatting

* Add in Changlog and restructure the quant 2d example

* fix CMake

* support the change for blockscale 2d

* fix the test file

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Sami Remes
2025-11-03 00:49:20 +00:00
committed by GitHub
parent 73f637894d
commit 16e85cf179
24 changed files with 476 additions and 363 deletions

View File

@@ -10,7 +10,7 @@ namespace ck_tile {
// A is block window on shared memory
// BQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
// B is block window on block distributed tensor.
// C is block distributed tensor
template <typename Problem_, typename BlockPolicy_>
@@ -24,6 +24,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -47,8 +51,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
@@ -58,13 +61,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
static constexpr index_t KPerBlockBQ = KPerBlock / kQuantGroupSize;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
(WG::kK + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read

View File

@@ -46,7 +46,7 @@ struct BlockGemmAQuantBase
// A is block window on shared memory
// AQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
// Consecutive QuantGroupSize elements of A are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_,
@@ -66,16 +66,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
// Threadblock GEMM tile size
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize;
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -101,20 +101,20 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / kQuantGroupSize > 0,
static_assert(KPerBlock / QuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,

View File

@@ -46,7 +46,7 @@ struct BlockGemmBQuantBase
// A is block window on shared memory
// BQ (scale tensor) is block distributed tensor.
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_,
@@ -66,16 +66,18 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
// Threadblock GEMM tile size
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t BQPerBlock = KPerBlock / kQuantGroupSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -101,20 +103,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / kQuantGroupSize > 0,
static_assert(KPerBlock / QuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
@@ -340,23 +342,17 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
}
});
// Need to multiply bquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
// These elements are in different rows, need to get the scale value
// for the corresponding row.
// Based on bquant's tile distribution, it can be inferred which
// lane holds the relevant scale. For example, the scales corresponding
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
constexpr index_t reg_offset = nIter * Traits::BQPerBlock + kQScale;
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(