[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)

* initial commit

* preshuffleQuant support for ABQuant

* fix mxfp4 to use correct QuantGroupSize

* addressing review comments and separated PreshuffleQuant for A and B

* updated grouped gemm example for updated traits definition

* fix for CI failure

* updated grouped_gemm_abquant test for updated traits definition

* updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
Khushbu Agarwal
2026-01-28 19:45:09 -08:00
committed by GitHub
parent e3556fed04
commit 9b168082b7
33 changed files with 490 additions and 367 deletions

View File

@@ -96,9 +96,9 @@ struct AQPickerCommon : public BlockGemmQuantBase
if constexpr(Traits::TransposeC) // transposed C
{
index_t reg_offset =
Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
Traits::APreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
if constexpr(Traits::PreshuffleQuant)
if constexpr(Traits::APreshuffleQuant)
{
auto pull_from_lane =
(__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale;
@@ -121,7 +121,7 @@ struct AQPickerCommon : public BlockGemmQuantBase
}
else
{
if constexpr(Traits::PreshuffleQuant)
if constexpr(Traits::APreshuffleQuant)
{
// A view is created on top of the preshuffled AQ, where each row of
// the view is composed of a row from a warp tile within an AQ block

View File

@@ -69,7 +69,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
@@ -127,9 +128,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -162,12 +163,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
@@ -289,9 +290,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
CBlockTensor::PackedSize>{};
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
kQScale;
}
else

View File

@@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -53,7 +53,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
@@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
@@ -173,7 +173,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
@@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
else
{
index_t reg_offset = [&]() {
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
{
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
KPerBlockBQ +
kQScale;
}
else

View File

@@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
@@ -75,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
@@ -134,8 +136,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
using Base = BlockGemmQuantBase;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -156,7 +162,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
@@ -354,11 +361,24 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
return kQScale;
}
else
{
return nIter;
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

View File

@@ -34,7 +34,7 @@ struct AQuantBlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
@@ -43,7 +43,7 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -69,20 +69,20 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, AQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
integer_divide_ceil(WarpGemm::kK, AQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
// The A-quant group's K extent must span a whole number of WarpGemm K steps;
// the diagnostic states the requirement in the same direction as the check.
static_assert(AQuantGroupSize::kK % WarpGemm::kK == 0,
              "Error! AQuantGroupSize::kK should be a multiple of WarpGemm::kK");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! AQuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / QuantGroupSize::kK > 0,
static_assert(KPerBlock / AQuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
@@ -110,8 +110,8 @@ struct AQuantBlockUniversalGemmAsBsCr
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool TransposeC = Problem::TransposeC;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool TransposeC = Problem::TransposeC;
};
public:

View File

@@ -36,7 +36,7 @@ struct BQuantBlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
@@ -46,8 +46,8 @@ struct BQuantBlockUniversalGemmAsBsCr
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN;
static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -72,23 +72,23 @@ struct BQuantBlockUniversalGemmAsBsCr
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK);
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
// The B-quant group's K extent must span a whole number of WarpGemm K steps;
// the diagnostic states the requirement in the same direction as the check.
static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0,
              "Error! BQuantGroupSize::kK should be a multiple of WarpGemm::kK");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! BQuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
static_assert(KPerBlock / QuantGroupSize::kK > 0,
static_assert(KPerBlock / BQuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
@@ -153,7 +153,7 @@ struct BQuantBlockUniversalGemmAsBsCr
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
@@ -317,25 +317,21 @@ struct BQuantBlockUniversalGemmAsBsCr
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
// constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >
(NWarp * WarpGemm::kN))
if constexpr(GemmTraits::BQuantGroupSize::kN >
(NWarp * WarpGemm::kN) &&
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
{
if constexpr(Traits::NPerBlock ==
GemmTraits::QuantGroupSize::kN)
return kQScale;
else
return nIter; // for prefill needs kQscale, for decode needs
// nIter
return kQScale; // prefill: one quant group per block
}
else
{
return nIter;
return nIter; // decode or multiple groups per warp
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
@@ -370,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr
{
// Multiply bquant with accumulated C
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >=
if constexpr(GemmTraits::BQuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
return (nIter * NWarp * WarpGemm::kN) /
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
GemmTraits::BQuantGroupSize::kN *
Traits::KQPerBlock +
kQScale;
else
{

View File

@@ -67,15 +67,27 @@ struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
};
template <typename, typename = void>
struct is_quantpreshuffle_enabled
struct is_Aquantpreshuffle_enabled
{
static constexpr bool value = false;
};
template <typename T>
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
struct is_Aquantpreshuffle_enabled<T, std::void_t<decltype(T::APreshuffleQuant)>>
{
static constexpr bool value = T::PreshuffleQuant;
static constexpr bool value = T::APreshuffleQuant;
};
// Detection trait for the B-side quant preshuffle flag.
// Fallback: used when the queried type exposes no `BPreshuffleQuant` member,
// in which case the feature is reported as disabled.
template <class, class = void>
struct is_Bquantpreshuffle_enabled
{
    static constexpr bool value = false;
};

// Preferred match: selected whenever `PipelineT::BPreshuffleQuant` is a
// well-formed expression; the trait then simply forwards that member's value.
template <class PipelineT>
struct is_Bquantpreshuffle_enabled<PipelineT,
                                   std::void_t<decltype(PipelineT::BPreshuffleQuant)>>
{
    static constexpr bool value = PipelineT::BPreshuffleQuant;
};
template <typename, typename = void>
@@ -206,8 +218,10 @@ struct QuantGemmKernel
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant =
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool APreshuffleQuant =
detail::is_Aquantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool BPreshuffleQuant =
detail::is_Bquantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline_>::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
@@ -476,7 +490,7 @@ struct QuantGemmKernel
{
// Step 1: Create tensor view for AQ
const auto& aq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
@@ -533,7 +547,7 @@ struct QuantGemmKernel
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!PreshuffleQuant)
!APreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
@@ -571,13 +585,13 @@ struct QuantGemmKernel
// Step 2: Create tile window (no padding for AQ)
const auto& aq_block_window = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
@@ -587,11 +601,19 @@ struct QuantGemmKernel
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!APreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
"ABQuantGrouped requires RowMajor AQ layout");
}
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_tensor_view,
@@ -605,17 +627,6 @@ struct QuantGemmKernel
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_tensor_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_tensor_view,
@@ -808,14 +819,15 @@ struct QuantGemmKernel
number<1>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped)
{
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
@@ -824,48 +836,42 @@ struct QuantGemmKernel
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(
bq_ptr,
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
QuantGroupSize::kN,
ck_tile::integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
BQuantGroupSize::kN,
kargs.QK_B);
}
else
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"ABQuantGrouped requires ColumnMajor BQ layout");
}
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
return nullptr;
@@ -881,28 +887,29 @@ struct QuantGemmKernel
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(BPreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// Number of N-dimension quantization groups per block
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
constexpr auto block_n = (BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
: BQuantGroupSize::kN / TilePartitioner::NPerBlock;
// Number of N-dimension elements per warp
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
// Determine how many warps share the same scale in N-dimension
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
constexpr auto warp_per_group = (BQuantGroupSize::kN < warp_n)
? (warp_n / BQuantGroupSize::kN)
: (BQuantGroupSize::kN / warp_n);
// Number of K-dimension quantization groups per block
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / BQuantGroupSize::kK;
// The pre-shuffled layout flattens warp_n ×
// bqk_per_block scales per row, Padded up to warp_size
@@ -911,25 +918,25 @@ struct QuantGemmKernel
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
// Adapts based on fine vs coarse quantization granularity:
// - Fine-grained (QuantGroupSize::kN < warp_n):
// - Fine-grained (BQuantGroupSize::kN < warp_n):
// Multiple quant groups per warp → fewer rows needed per block.
// height = block_n / warp_per_group
//
// - Coarse-grained (QuantGroupSize::kN >= warp_n):
// - Coarse-grained (BQuantGroupSize::kN >= warp_n):
// Each row represents one quant group.
// height = block_n
constexpr auto tile_window_height =
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
(BQuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
{
block_n_idx = block_n_idx >> 1;
}
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
{
return make_tile_window(
bq_tensor_view,
@@ -946,17 +953,22 @@ struct QuantGemmKernel
}
else
{
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
    // This assert inspects BQLayout, so the diagnostic must describe the BQ
    // layout requirement (the previous text named the AQ RowMajor requirement).
    static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
                  "ABQuantGrouped requires ColumnMajor BQ layout");
}
constexpr auto tensor_dim =
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
(BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
: 1;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
make_tuple(number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{},
number<tensor_dim>{}),
{0, i_n / QuantGroupSize::kN});
{0, i_n / BQuantGroupSize::kN});
}
else
{
@@ -964,21 +976,11 @@ struct QuantGemmKernel
return make_tile_window(
bq_tensor_view,
make_tuple(number<tensor_dim>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{}),
{i_n / BQuantGroupSize::kN, 0});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
else
{
return nullptr;
@@ -1223,7 +1225,7 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
index_t m = 0;
if constexpr(PreshuffleQuant)
if constexpr(APreshuffleQuant)
{
m = kargs.M;
}
@@ -1233,7 +1235,7 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
index_t n = 0;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
n = kargs.N;
}
@@ -1244,9 +1246,9 @@ struct QuantGemmKernel
{
index_t m = 0;
index_t n = 0;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
m = kargs.M;
// m = kargs.M;
n = kargs.N;
}
return GemmPipeline{}.template operator()(a_block_window,

View File

@@ -72,7 +72,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
static constexpr index_t NPerBlockBQ =
(BQuantGroupSize::kN <= BlockGemmShape::kN)
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
: 1;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
@@ -95,7 +98,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -264,7 +268,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(
PreshuffleQuant ||
BPreshuffleQuant ||
(is_bq_row_major
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
@@ -323,15 +327,18 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// only row_major for AQ
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant
APreshuffleQuant
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
(BPreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
@@ -484,7 +491,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
currIdx = (currIdx + 1) % 2;
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -495,7 +502,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// Note: BDataType gets converted during loading from PkInt4
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(

View File

@@ -12,21 +12,21 @@ namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t KPerBlockAQ = KPerBlock / AQuantGroupSize::kK;
static_assert(KPerBlock % QuantGroupSize::kK == 0,
static_assert(KPerBlock % AQuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize");
// Create DRAM tile window for AQ

View File

@@ -23,19 +23,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
using OverrideADataType =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -60,7 +60,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -78,7 +78,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -99,7 +99,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(),
Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
// clang-format on
}
@@ -156,7 +156,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -216,7 +216,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
static_assert(!APreshuffleQuant, "Memory pipeline does not support APreshuffleQuant!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&

View File

@@ -32,22 +32,22 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
if constexpr(PreshuffleQuant)
if constexpr(APreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
BlockGemmShape,
@@ -57,7 +57,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
APreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
@@ -89,7 +89,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockAQ,
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
APreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
@@ -103,7 +103,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
MPerBlock, // XPerTile
KPerBlockAQ,
VecLoadSize,
PreshuffleQuant>;
APreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution_transposed();
}
}

View File

@@ -20,19 +20,19 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
using OverrideADataType =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -57,7 +57,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -75,7 +75,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -96,7 +96,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName());
// clang-format on
}
@@ -152,7 +152,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -271,7 +271,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
// only row_major for AQ
const AQDramTileWindowStep aq_dram_tile_window_step =
PreshuffleQuant
APreshuffleQuant
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
BlockGemm::WarpGemm::kM,
0)

View File

@@ -12,13 +12,13 @@ namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
@@ -27,16 +27,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ =
(QuantGroupSize::kN <= NPerBlock) ? NPerBlock / QuantGroupSize::kN : 1;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
(BQuantGroupSize::kN <= NPerBlock) ? NPerBlock / BQuantGroupSize::kN : 1;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
// "NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");
// static_assert(NPerBlock % BQuantGroupSize::kN == 0,
// "NPerBlock must be a multiple of BQuantGroupSize::kN");
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
"KPerBlock must be a multiple of BQuantGroupSize::kK");
// Create DRAM tile window for BQ
template <typename BQDramBlockWindowTmp>

View File

@@ -43,14 +43,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock)
? NPerBlock / Problem::QuantGroupSize::kN
: 1;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = (Problem::BQuantGroupSize::kN <= NPerBlock)
? NPerBlock / Problem::BQuantGroupSize::kN
: 1;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
@@ -61,7 +61,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I2),
Problem::TransposeC>;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
BlockGemmShape,
@@ -72,7 +72,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout,
PreshuffleQuant>;
BPreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else

View File

@@ -26,12 +26,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
@@ -45,7 +45,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
ADataType,
BDataType>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
@@ -66,11 +66,11 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ =
(QuantGroupSize::kN <= BlockGemmShape::kN)
? integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN)
(BQuantGroupSize::kN <= BlockGemmShape::kN)
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
: 1;
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -88,7 +88,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
@@ -109,7 +109,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
// clang-format on
}
@@ -165,7 +165,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -252,7 +252,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(
PreshuffleQuant ||
BPreshuffleQuant ||
(is_bq_row_major
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
@@ -304,9 +304,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant)
(BPreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)

View File

@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
index_t XPerTile,
index_t KPerBlockAQ,
index_t VecSize,
bool PreshuffleQuant>
bool APreshuffleQuant>
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -72,7 +72,7 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
if constexpr(PreshuffleQuant)
if constexpr(APreshuffleQuant)
{
// # of elements per thread
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
@@ -193,8 +193,8 @@ template <typename BlockGemmShape,
index_t NPerTile,
index_t NPerQ,
index_t KPerQ,
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool PreshuffleQuant = false>
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool BPreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
static constexpr index_t warp_size = get_warp_size();
@@ -212,10 +212,11 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
// Preshuffle only supported for ColumnMajor currently
static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
"PreshuffleQuant only supported for ColumnMajor BQLayout");
static_assert(
!(BPreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
"PreshuffleQuant only supported for ColumnMajor BQLayout");
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
// =============================================================================
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION

View File

@@ -12,13 +12,13 @@ namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
@@ -26,16 +26,16 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Probl
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static constexpr index_t NPerBlockBQ = NPerBlock / BQuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
static_assert(NPerBlock % QuantGroupSize::kN == 0,
"NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");
static_assert(NPerBlock % BQuantGroupSize::kN == 0,
"NPerBlock must be a multiple of BQuantGroupSize::kN");
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
"KPerBlock must be a multiple of BQuantGroupSize::kK");
// Create DRAM tile window for BQ
template <typename BQDramBlockWindowTmp>

View File

@@ -22,9 +22,9 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
@@ -76,7 +76,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2
constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
@@ -109,7 +109,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of QuantGroupSize!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,

View File

@@ -24,15 +24,15 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
@@ -58,8 +58,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -93,7 +93,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
// clang-format on
}
@@ -149,7 +149,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
@@ -412,7 +412,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2);
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK;
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / BQuantGroupSize::kK;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start

View File

@@ -120,7 +120,7 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename QuantGroupSize_,
typename AQuantGroupSize_,
bool TransposeC_,
typename ComputeDataType_ = BDataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
@@ -133,7 +133,7 @@ using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
CDataType_,
BlockGemmShape_,
Traits_,
QuantGroupSize_,
AQuantGroupSize_,
void,
TransposeC_,
ComputeDataType_,
@@ -147,7 +147,7 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename QuantGroupSize_,
typename BQuantGroupSize_,
typename ComputeDataType_ = ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
@@ -160,7 +160,7 @@ using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
BlockGemmShape_,
Traits_,
void,
QuantGroupSize_,
BQuantGroupSize_,
false, // no TransposeC
ComputeDataType_,
Scheduler_,

View File

@@ -25,7 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
@@ -69,14 +69,14 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
using Base::m_preload;
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);
static constexpr index_t GetVectorSizeBQ()
{
@@ -94,7 +94,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
// clang-format on
}
@@ -115,7 +115,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
// then by vector width to get an approximate number of vector loads.
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
QuantGroupSize::kK * QuantGroupSize::kK),
BQuantGroupSize::kK * BQuantGroupSize::kK),
VectorLoadSize);
// TODO: Hard-coded; needs to change in the future. Number of instructions emitted per iteration.
@@ -360,11 +360,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
BQBlockTile bq_block_tile, bq_block_tile_2;
bq_block_tile = load_tile(bq_copy_dram_window);
// move BQ to tile 1
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
@@ -437,11 +437,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
@@ -474,11 +474,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
bq_block_tile = load_tile(bq_copy_dram_window);
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});

View File

@@ -33,7 +33,8 @@ inline std::string quant_type_to_string(QuantType quant_type)
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
bool PreshuffleQuant_,
bool APreshuffleQuant_,
bool BPreshuffleQuant_,
bool PreshuffleB_,
typename ALayout_,
typename BLayout_,
@@ -71,8 +72,9 @@ struct TileGemmQuantTraits
static constexpr index_t NumWaveGroups = 1;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
static constexpr bool PreshuffleB = PreshuffleB_;
static constexpr bool APreshuffleQuant = APreshuffleQuant_;
static constexpr bool BPreshuffleQuant = BPreshuffleQuant_;
static constexpr bool PreshuffleB = PreshuffleB_;
};
} // namespace ck_tile