mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)
* initial commit * preshuffleQuant support for ABQuant * fix mxfp4 to use correct QuantGroupSize * addressing review comments and seperated Preshufflequant for A and B * updated grouped gemm example for updated traits definition * fix for CI failure * updated grouped_gemm_abquant test for updated traits definition * updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
@@ -96,9 +96,9 @@ struct AQPickerCommon : public BlockGemmQuantBase
|
||||
if constexpr(Traits::TransposeC) // transposed C
|
||||
{
|
||||
index_t reg_offset =
|
||||
Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
|
||||
Traits::APreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
|
||||
auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
if constexpr(Traits::PreshuffleQuant)
|
||||
if constexpr(Traits::APreshuffleQuant)
|
||||
{
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale;
|
||||
@@ -121,7 +121,7 @@ struct AQPickerCommon : public BlockGemmQuantBase
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Traits::PreshuffleQuant)
|
||||
if constexpr(Traits::APreshuffleQuant)
|
||||
{
|
||||
// A view is created on top of the preshuffled AQ, where each row of
|
||||
// the view is composed of a row from a warp tile within an AQ block
|
||||
|
||||
@@ -69,7 +69,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
@@ -127,9 +128,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -162,12 +163,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
static constexpr auto MIter_2nd_last =
|
||||
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
@@ -289,9 +290,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
index_t reg_offset = [&]() {
|
||||
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
{
|
||||
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
|
||||
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -53,7 +53,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp =
|
||||
@@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
static constexpr auto MIter_2nd_last =
|
||||
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
@@ -173,7 +173,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
|
||||
@@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
else
|
||||
{
|
||||
index_t reg_offset = [&]() {
|
||||
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
{
|
||||
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
|
||||
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
|
||||
KPerBlockBQ +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
@@ -75,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
@@ -134,8 +136,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using OverrideBDataType = std::conditional_t<
|
||||
std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using Base = BlockGemmQuantBase;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -156,7 +162,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
|
||||
|
||||
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
|
||||
|
||||
@@ -354,11 +361,24 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
|
||||
aq_block_tensor);
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::BQuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN) &&
|
||||
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
|
||||
{
|
||||
return kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter;
|
||||
}
|
||||
}();
|
||||
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
@@ -34,7 +34,7 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
@@ -43,7 +43,7 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -69,20 +69,20 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(KPerBlock, AQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WarpGemm::kK, AQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(AQuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of AQuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! AQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
static_assert(KPerBlock / AQuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
@@ -110,8 +110,8 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool TransposeC = Problem::TransposeC;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool TransposeC = Problem::TransposeC;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
@@ -36,7 +36,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
@@ -46,8 +46,8 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN;
|
||||
static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -72,23 +72,23 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of BQuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! BQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
static_assert(KPerBlock / BQuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
@@ -153,7 +153,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
|
||||
|
||||
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
|
||||
|
||||
@@ -317,25 +317,21 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
// constexpr index_t reg_offset = nIter;
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN))
|
||||
if constexpr(GemmTraits::BQuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN) &&
|
||||
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
|
||||
{
|
||||
if constexpr(Traits::NPerBlock ==
|
||||
GemmTraits::QuantGroupSize::kN)
|
||||
return kQScale;
|
||||
else
|
||||
return nIter; // for prefill needs kQscale, for decode needs
|
||||
// nIter
|
||||
return kQScale; // prefill: one quant group per block
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter;
|
||||
return nIter; // decode or multiple groups per warp
|
||||
}
|
||||
}();
|
||||
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
@@ -370,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
{
|
||||
// Multiply bquant with accumulated C
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >=
|
||||
if constexpr(GemmTraits::BQuantGroupSize::kN >=
|
||||
(NWarp * WarpGemm::kN))
|
||||
return (nIter * NWarp * WarpGemm::kN) /
|
||||
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
|
||||
GemmTraits::BQuantGroupSize::kN *
|
||||
Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
|
||||
@@ -67,15 +67,27 @@ struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
struct is_quantpreshuffle_enabled
|
||||
struct is_Aquantpreshuffle_enabled
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
|
||||
struct is_Aquantpreshuffle_enabled<T, std::void_t<decltype(T::APreshuffleQuant)>>
|
||||
{
|
||||
static constexpr bool value = T::PreshuffleQuant;
|
||||
static constexpr bool value = T::APreshuffleQuant;
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
struct is_Bquantpreshuffle_enabled
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_Bquantpreshuffle_enabled<T, std::void_t<decltype(T::BPreshuffleQuant)>>
|
||||
{
|
||||
static constexpr bool value = T::BPreshuffleQuant;
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
@@ -206,8 +218,10 @@ struct QuantGemmKernel
|
||||
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool PreshuffleQuant =
|
||||
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
|
||||
static constexpr bool APreshuffleQuant =
|
||||
detail::is_Aquantpreshuffle_enabled<GemmPipeline_>::value;
|
||||
static constexpr bool BPreshuffleQuant =
|
||||
detail::is_Bquantpreshuffle_enabled<GemmPipeline_>::value;
|
||||
static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline_>::value;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
@@ -476,7 +490,7 @@ struct QuantGemmKernel
|
||||
{
|
||||
// Step 1: Create tensor view for AQ
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
|
||||
@@ -533,7 +547,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped) &&
|
||||
!PreshuffleQuant)
|
||||
!APreshuffleQuant)
|
||||
{
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -571,13 +585,13 @@ struct QuantGemmKernel
|
||||
|
||||
// Step 2: Create tile window (no padding for AQ)
|
||||
const auto& aq_block_window = [&]() {
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
|
||||
constexpr auto tile_window_width =
|
||||
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
|
||||
constexpr auto tile_window_height = block_m / warp_m;
|
||||
@@ -587,11 +601,19 @@ struct QuantGemmKernel
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_m_idx * tile_window_height, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped) &&
|
||||
!APreshuffleQuant)
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
|
||||
"ABQuantGrouped requires RowMajor AQ layout");
|
||||
}
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(aq_tensor_view,
|
||||
@@ -605,17 +627,6 @@ struct QuantGemmKernel
|
||||
{0, i_m});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_tensor_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(aq_tensor_view,
|
||||
@@ -808,14 +819,15 @@ struct QuantGemmKernel
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"PreshuffleQuant with BQuantGrouped currently only supports "
|
||||
"ColumnMajor BQ layout");
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
|
||||
return MakePreshuffledQuantTensorView<
|
||||
GemmPipeline::KPerBlockBQ,
|
||||
@@ -824,48 +836,42 @@ struct QuantGemmKernel
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
|
||||
GemmPipeline::GetVectorSizeBQ()>(
|
||||
bq_ptr,
|
||||
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
|
||||
QuantGroupSize::kN,
|
||||
ck_tile::integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
|
||||
BQuantGroupSize::kN,
|
||||
kargs.QK_B);
|
||||
}
|
||||
else
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"ABQuantGrouped requires ColumnMajor BQ layout");
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
|
||||
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
|
||||
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
|
||||
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
|
||||
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
|
||||
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
|
||||
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
|
||||
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
|
||||
integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
|
||||
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
@@ -881,28 +887,29 @@ struct QuantGemmKernel
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
if constexpr(PreshuffleQuant)
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
|
||||
// Number of N-dimension quantization groups per block
|
||||
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
|
||||
constexpr auto block_n = (BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
|
||||
: BQuantGroupSize::kN / TilePartitioner::NPerBlock;
|
||||
|
||||
// Number of N-dimension elements per warp
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
// Determine how many warps share the same scale in N-dimension
|
||||
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
constexpr auto warp_per_group = (BQuantGroupSize::kN < warp_n)
|
||||
? (warp_n / BQuantGroupSize::kN)
|
||||
: (BQuantGroupSize::kN / warp_n);
|
||||
|
||||
// Number of K-dimension quantization groups per block
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
// The pre-shuffled layout flattens warp_n ×
|
||||
// bqk_per_block scales per row, Padded up to warp_size
|
||||
@@ -911,25 +918,25 @@ struct QuantGemmKernel
|
||||
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
|
||||
|
||||
// Adapts based on fine vs coarse quantization granularity:
|
||||
// - Fine-grained (QuantGroupSize::kN < warp_n):
|
||||
// - Fine-grained (BQuantGroupSize::kN < warp_n):
|
||||
// Multiple quant groups per warp → fewer rows needed per block.
|
||||
// height = block_n / warp_per_group
|
||||
//
|
||||
// - Coarse-grained (QuantGroupSize::kN >= warp_n):
|
||||
// - Coarse-grained (BQuantGroupSize::kN >= warp_n):
|
||||
// Each row represents one quant group.
|
||||
// height = block_n
|
||||
constexpr auto tile_window_height =
|
||||
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
|
||||
(BQuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
|
||||
|
||||
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
|
||||
|
||||
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
block_n_idx = block_n_idx >> 1;
|
||||
}
|
||||
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
@@ -946,17 +953,22 @@ struct QuantGemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
|
||||
"ABQuantGrouped requires RowMajor AQ layout");
|
||||
}
|
||||
constexpr auto tensor_dim =
|
||||
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
(BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
|
||||
: 1;
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
make_tuple(number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{},
|
||||
number<tensor_dim>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
{0, i_n / BQuantGroupSize::kN});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -964,21 +976,11 @@ struct QuantGemmKernel
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<tensor_dim>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{}),
|
||||
{i_n / BQuantGroupSize::kN, 0});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
@@ -1223,7 +1225,7 @@ struct QuantGemmKernel
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
{
|
||||
index_t m = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(APreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
}
|
||||
@@ -1233,7 +1235,7 @@ struct QuantGemmKernel
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
@@ -1244,9 +1246,9 @@ struct QuantGemmKernel
|
||||
{
|
||||
index_t m = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
// m = kargs.M;
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
|
||||
@@ -72,7 +72,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(BQuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
|
||||
: 1;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
@@ -95,7 +98,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -264,7 +268,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
static_assert(
|
||||
PreshuffleQuant ||
|
||||
BPreshuffleQuant ||
|
||||
(is_bq_row_major
|
||||
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
@@ -323,15 +327,18 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
// only row_major for AQ
|
||||
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
PreshuffleQuant
|
||||
APreshuffleQuant
|
||||
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
0)
|
||||
(BPreshuffleQuant)
|
||||
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0)
|
||||
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
|
||||
@@ -484,7 +491,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
|
||||
currIdx = (currIdx + 1) % 2;
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -495,7 +502,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
// Note: BDataType gets converted during loading from PkInt4
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
|
||||
|
||||
@@ -12,21 +12,21 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / AQuantGroupSize::kK;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
static_assert(KPerBlock % AQuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
|
||||
// Create DRAM tile window for AQ
|
||||
|
||||
@@ -23,19 +23,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
|
||||
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
|
||||
using OverrideADataType =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -60,7 +60,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -78,7 +78,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -99,7 +99,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(),
|
||||
Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
|
||||
// clang-format on
|
||||
}
|
||||
@@ -156,7 +156,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -216,7 +216,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
|
||||
static_assert(!APreshuffleQuant, "Memory pipeline does not support APreshuffleQuant!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
|
||||
@@ -32,22 +32,22 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(APreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
|
||||
BlockGemmShape,
|
||||
@@ -57,7 +57,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
APreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
@@ -89,7 +89,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
APreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
@@ -103,7 +103,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
MPerBlock, // XPerTile
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
APreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution_transposed();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,19 +20,19 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
|
||||
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
|
||||
using OverrideADataType =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -57,7 +57,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -75,7 +75,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -96,7 +96,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -271,7 +271,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
|
||||
// only row_major for AQ
|
||||
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
PreshuffleQuant
|
||||
APreshuffleQuant
|
||||
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
|
||||
@@ -12,13 +12,13 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -27,16 +27,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(QuantGroupSize::kN <= NPerBlock) ? NPerBlock / QuantGroupSize::kN : 1;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
(BQuantGroupSize::kN <= NPerBlock) ? NPerBlock / BQuantGroupSize::kN : 1;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
|
||||
|
||||
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
// "NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
// static_assert(NPerBlock % BQuantGroupSize::kN == 0,
|
||||
// "NPerBlock must be a multiple of BQuantGroupSize::kN");
|
||||
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of BQuantGroupSize::kK");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
|
||||
@@ -43,14 +43,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock)
|
||||
? NPerBlock / Problem::QuantGroupSize::kN
|
||||
: 1;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = (Problem::BQuantGroupSize::kN <= NPerBlock)
|
||||
? NPerBlock / Problem::BQuantGroupSize::kN
|
||||
: 1;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
@@ -61,7 +61,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
|
||||
BlockGemmShape,
|
||||
@@ -72,7 +72,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
Problem::BQuantGroupSize::kN,
|
||||
Problem::BQuantGroupSize::kK,
|
||||
BQLayout,
|
||||
PreshuffleQuant>;
|
||||
BPreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
|
||||
@@ -26,12 +26,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
@@ -45,7 +45,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
@@ -66,11 +66,11 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(QuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN)
|
||||
(BQuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
|
||||
: 1;
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -88,7 +88,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -109,7 +109,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -252,7 +252,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
static_assert(
|
||||
PreshuffleQuant ||
|
||||
BPreshuffleQuant ||
|
||||
(is_bq_row_major
|
||||
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
@@ -304,9 +304,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant)
|
||||
(BPreshuffleQuant)
|
||||
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0)
|
||||
|
||||
@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
|
||||
index_t XPerTile,
|
||||
index_t KPerBlockAQ,
|
||||
index_t VecSize,
|
||||
bool PreshuffleQuant>
|
||||
bool APreshuffleQuant>
|
||||
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -72,7 +72,7 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(APreshuffleQuant)
|
||||
{
|
||||
// # of elements per thread
|
||||
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
|
||||
@@ -193,8 +193,8 @@ template <typename BlockGemmShape,
|
||||
index_t NPerTile,
|
||||
index_t NPerQ,
|
||||
index_t KPerQ,
|
||||
typename BQLayout = tensor_layout::gemm::ColumnMajor,
|
||||
bool PreshuffleQuant = false>
|
||||
typename BQLayout = tensor_layout::gemm::ColumnMajor,
|
||||
bool BPreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
@@ -212,10 +212,11 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
// Preshuffle only supported for ColumnMajor currently
|
||||
static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
|
||||
"PreshuffleQuant only supported for ColumnMajor BQLayout");
|
||||
static_assert(
|
||||
!(BPreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
|
||||
"PreshuffleQuant only supported for ColumnMajor BQLayout");
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
// =============================================================================
|
||||
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
|
||||
|
||||
@@ -12,13 +12,13 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -26,16 +26,16 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Probl
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / BQuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
static_assert(NPerBlock % BQuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of BQuantGroupSize::kN");
|
||||
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of BQuantGroupSize::kK");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
|
||||
@@ -22,9 +22,9 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
@@ -76,7 +76,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2
|
||||
constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
@@ -109,7 +109,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
|
||||
@@ -24,15 +24,15 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -58,8 +58,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -93,7 +93,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -149,7 +149,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -412,7 +412,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2);
|
||||
|
||||
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK;
|
||||
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / BQuantGroupSize::kK;
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename QuantGroupSize_,
|
||||
typename AQuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
@@ -133,7 +133,7 @@ using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
AQuantGroupSize_,
|
||||
void,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
@@ -147,7 +147,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename QuantGroupSize_,
|
||||
typename BQuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
@@ -160,7 +160,7 @@ using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
void,
|
||||
QuantGroupSize_,
|
||||
BQuantGroupSize_,
|
||||
false, // no TransposeC
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
|
||||
@@ -25,7 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
@@ -69,14 +69,14 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
using Base::m_preload;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN);
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
@@ -94,7 +94,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
// then by vector width to get an approximate number of vector loads.
|
||||
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
|
||||
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
|
||||
QuantGroupSize::kK * QuantGroupSize::kK),
|
||||
BQuantGroupSize::kK * BQuantGroupSize::kK),
|
||||
VectorLoadSize);
|
||||
|
||||
// ToDo: Hardcoded, need to change in future. How many instruction emit per iteration
|
||||
@@ -360,11 +360,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BQBlockTile bq_block_tile, bq_block_tile_2;
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
// move BQ to tile 1
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
@@ -437,11 +437,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
@@ -474,11 +474,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
|
||||
@@ -33,7 +33,8 @@ inline std::string quant_type_to_string(QuantType quant_type)
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool PreshuffleQuant_,
|
||||
bool APreshuffleQuant_,
|
||||
bool BPreshuffleQuant_,
|
||||
bool PreshuffleB_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
@@ -71,8 +72,9 @@ struct TileGemmQuantTraits
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
static constexpr bool PreshuffleB = PreshuffleB_;
|
||||
static constexpr bool APreshuffleQuant = APreshuffleQuant_;
|
||||
static constexpr bool BPreshuffleQuant = BPreshuffleQuant_;
|
||||
static constexpr bool PreshuffleB = PreshuffleB_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user