mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] B matrix 2D block scale gemm (#3074)
* Refactor quant group size to be configurable for M/N/K, not just K * add some asserts for configurations not implemented * start setting of group size for N dimension * enable 2d for reference quant gemm * WIP: trying to figure out tile dstr and/or indexing for scale matrix * WIP * Fix handling of n dim blocks in tile windows etc * remove commented code and enable all tests again * fix formatting * Add more specialized tile distributions * Enable NWarps replication for bquant tile dstr * fix formatting * fix format * Fix some issues from the merge * fix formatting * one more fix to tile dstr, and revert debug initialization * Remove commented code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * simplify conditions that are needed for tile distributions * only enable the working group sizes in tests * fix formatting * Update tile distribution for 2D bquant * add some documentation and 2d block scale example * fix formatting * Add in Changlog and restructure the quant 2d example * fix CMake * support the change for blockscale 2d * fix the test file --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -10,7 +10,7 @@ namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// BQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
|
||||
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on block distributed tensor.
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename BlockPolicy_>
|
||||
@@ -24,6 +24,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -47,8 +51,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp =
|
||||
@@ -58,13 +61,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
static constexpr auto MIter_2nd_last =
|
||||
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
(WG::kK + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
|
||||
@@ -46,7 +46,7 @@ struct BlockGemmAQuantBase
|
||||
|
||||
// A is block window on shared memory
|
||||
// AQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
|
||||
// Consecutive QuantGroupSize elements of A are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_,
|
||||
@@ -66,16 +66,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
// Threadblock GEMM tile size
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -101,20 +101,20 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / kQuantGroupSize > 0,
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
|
||||
@@ -46,7 +46,7 @@ struct BlockGemmBQuantBase
|
||||
|
||||
// A is block window on shared memory
|
||||
// BQ (scale tensor) is block distributed tensor.
|
||||
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
|
||||
// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_,
|
||||
@@ -66,16 +66,18 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
// Threadblock GEMM tile size
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t BQPerBlock = KPerBlock / kQuantGroupSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -101,20 +103,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(kQuantGroupSize % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / kQuantGroupSize > 0,
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
@@ -340,23 +342,17 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
}
|
||||
});
|
||||
|
||||
// Need to multiply bquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
// These elements are in different rows, need to get the scale value
|
||||
// for the corresponding row.
|
||||
// Based on bquant's tile distribution, it can be inferred which
|
||||
// lane holds the relevant scale. For example, the scales corresponding
|
||||
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
|
||||
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
constexpr index_t reg_offset = nIter * Traits::BQPerBlock + kQScale;
|
||||
// Multiply bquant with accumulated C
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
|
||||
return (nIter * NWarp * WarpGemm::kN) /
|
||||
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
return nIter * Traits::KQPerBlock + kQScale;
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
|
||||
@@ -685,9 +685,10 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.QK_B, kargs.N),
|
||||
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(1, kargs.stride_BQ),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
@@ -831,10 +832,10 @@ struct QuantGemmKernel
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block =
|
||||
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
@@ -847,11 +848,12 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
|
||||
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
@@ -907,11 +909,12 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -18,6 +18,7 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
|
||||
@@ -25,10 +26,9 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize == 0,
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
|
||||
// Create DRAM tile window for AQ
|
||||
|
||||
@@ -86,6 +86,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -106,12 +110,11 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -147,7 +150,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize,
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
|
||||
Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
|
||||
// clang-format on
|
||||
}
|
||||
@@ -204,7 +207,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
|
||||
@@ -22,7 +22,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
|
||||
@@ -37,7 +37,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
@@ -99,8 +99,8 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize::kK!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
|
||||
@@ -91,6 +91,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -111,12 +115,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -152,7 +155,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -208,7 +211,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
|
||||
@@ -18,6 +18,7 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -25,11 +26,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize;
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
@@ -38,7 +44,7 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
|
||||
using YPerTile = number<NPerBlock>;
|
||||
using YPerTile = number<NPerBlockBQ>;
|
||||
using XPerTile = number<KPerBlockBQ>;
|
||||
|
||||
auto bq_copy_dram_window =
|
||||
|
||||
@@ -21,11 +21,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -36,9 +37,9 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
@@ -49,12 +50,13 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
Problem::TransposeC>;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockBQ,
|
||||
NPerBlock,
|
||||
VecLoadSize>;
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockBQ,
|
||||
NPerBlockBQ,
|
||||
Problem::QuantGroupSize::kN>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
@@ -65,8 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
|
||||
@@ -91,7 +91,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
@@ -111,12 +113,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -151,7 +154,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), "QuantGroupSize", QuantGroupSize);
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -207,7 +210,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -255,7 +258,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
|
||||
@@ -171,11 +171,9 @@ template <typename BlockGemmShape,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
index_t XPerQ>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
|
||||
@@ -186,34 +184,94 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
static constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarps * WarpGemm::kN);
|
||||
|
||||
static_assert(num_warps == MWarps * NWarps * KWarps);
|
||||
|
||||
// KWarps > 1 isn't supported
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
// # of elements per thread
|
||||
static constexpr index_t Y = YPerTile;
|
||||
static constexpr index_t YR = 1;
|
||||
|
||||
// Number of iters per warp
|
||||
// MIters are indexed using (Y0, Y1)
|
||||
static constexpr index_t X0 = NIterPerWarp;
|
||||
|
||||
// # of warps in Y dim
|
||||
static constexpr index_t X1 = NWarps;
|
||||
|
||||
static constexpr index_t X2 = WarpGemm::kN;
|
||||
|
||||
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along Y.");
|
||||
|
||||
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
|
||||
///
|
||||
/// This function determines the optimal thread distribution pattern for loading and applying
|
||||
/// quantization scales to the B matrix based on the quantization group size (XPerQ) relative
|
||||
/// to warp dimensions.
|
||||
///
|
||||
/// Three distinct distribution patterns are handled:
|
||||
///
|
||||
/// 1. Fine-grained quantization (XPerQ < WarpGemm::kN):
|
||||
/// - Multiple quantization groups exist within a single warp's N-dimension
|
||||
/// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp)
|
||||
/// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast
|
||||
/// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
|
||||
///
|
||||
/// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps):
|
||||
/// - Each warp handles exactly one quantization scale
|
||||
/// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN
|
||||
/// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
|
||||
///
|
||||
/// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps):
|
||||
/// - Quantization group spans multiple warps
|
||||
/// - All warps share the same scale value
|
||||
/// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
|
||||
///
|
||||
/// @return A static tile distribution encoding for the BQ scale tensor
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, YR>,
|
||||
tuple<sequence<Y>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
if constexpr(XPerQ < WarpGemm::kN)
|
||||
{
|
||||
// Case 1: Fine-grained - multiple quantization scales within a single warp
|
||||
constexpr index_t Y = YPerTile; // Full Y dimension of tile
|
||||
constexpr index_t YR = 1; // No Y replication needed
|
||||
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
|
||||
constexpr index_t X1 = NWarps; // Number of warps in N-dim
|
||||
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
|
||||
constexpr index_t XR = XPerQ; // Elements per quantization group
|
||||
|
||||
static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along X.");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, YR, XR>,
|
||||
tuple<sequence<Y>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
|
||||
{
|
||||
// Case 2: Medium-grained - one quantization scale per warp
|
||||
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
|
||||
constexpr auto X1 = NWarps / XR; // Warps per unique scale
|
||||
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 2, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1, 1>, sequence<2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else // XPerQ > WarpGemm::kN * NWarps
|
||||
{
|
||||
// Case 3: Coarse-grained - quantization group spans all warps
|
||||
// All warps in N-dimension share the same quantization scale
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<XPerTile>>,
|
||||
tuple<sequence<0, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1>, sequence<2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GroupSizes>
|
||||
struct QuantGroupShape
|
||||
{
|
||||
static constexpr index_t kM = GroupSizes::at(number<0>{});
|
||||
static constexpr index_t kN = GroupSizes::at(number<1>{});
|
||||
static constexpr index_t kK = GroupSizes::at(number<2>{});
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
return concat('_', "quant_group_shape", concat('x', kM, kN, kK));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename QuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
@@ -48,6 +48,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
using BQDataType = remove_cvref_t<BQDataType_>;
|
||||
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = QuantGroupSize_;
|
||||
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BLayout;
|
||||
@@ -67,12 +68,13 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
using AQLayout = remove_cvref_t<typename Traits::AQLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Traits::BQLayout>;
|
||||
|
||||
static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
|
||||
static_assert(BlockGemmShape::kM % QuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % QuantGroupSize::kK == 0);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -81,8 +83,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
"QuantGroupSize",
|
||||
kQuantGroupSize);
|
||||
QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -111,7 +112,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename QuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
@@ -137,7 +138,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
uint32_t QuantGroupSize_,
|
||||
typename QuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
@@ -175,7 +176,7 @@ using GemmRowColTensorQuantPipelineProblem =
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
1, // no group size applicable
|
||||
QuantGroupShape<sequence<1, 1, 1>>, // no group size applicable
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
|
||||
@@ -15,10 +15,11 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin
|
||||
{
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlock, KPerBlockBQ>();
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -25,6 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
@@ -68,10 +69,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
using Base::m_preload;
|
||||
|
||||
static constexpr index_t QuantGroupSize = Problem::kQuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize;
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
(kKPerBlock + QuantGroupSize - 1) / QuantGroupSize;
|
||||
integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
@@ -89,7 +90,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize);
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user