mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)
* initial commit * PreshuffleQuant support for ABQuant * fix mxfp4 to use correct QuantGroupSize * addressing review comments and separated PreshuffleQuant for A and B * updated grouped gemm example for updated traits definition * fix for CI failure * updated grouped_gemm_abquant test for updated traits definition * updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
@@ -96,9 +96,9 @@ struct AQPickerCommon : public BlockGemmQuantBase
|
||||
if constexpr(Traits::TransposeC) // transposed C
|
||||
{
|
||||
index_t reg_offset =
|
||||
Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
|
||||
Traits::APreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
|
||||
auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
if constexpr(Traits::PreshuffleQuant)
|
||||
if constexpr(Traits::APreshuffleQuant)
|
||||
{
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock + kQScale;
|
||||
@@ -121,7 +121,7 @@ struct AQPickerCommon : public BlockGemmQuantBase
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Traits::PreshuffleQuant)
|
||||
if constexpr(Traits::APreshuffleQuant)
|
||||
{
|
||||
// A view is created on top of the preshuffled AQ, where each row of
|
||||
// the view is composed of a row from a warp tile within an AQ block
|
||||
|
||||
@@ -69,7 +69,8 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
@@ -127,9 +128,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -162,12 +163,12 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
static constexpr auto MIter_2nd_last =
|
||||
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); // 128 / 128 = 1
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
@@ -289,9 +290,9 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
index_t reg_offset = [&]() {
|
||||
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
{
|
||||
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
|
||||
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN * KPerBlockBQ +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -25,9 +25,9 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -53,7 +53,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
static constexpr index_t NIterPerWarp =
|
||||
@@ -63,12 +63,12 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
static constexpr auto MIter_2nd_last =
|
||||
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WG::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WG::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
@@ -173,7 +173,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
auto pull_from_lane = (__lane_id() & (WG::kN - 1)) * KPerBlockBQ + kQScale;
|
||||
@@ -205,9 +205,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
else
|
||||
{
|
||||
index_t reg_offset = [&]() {
|
||||
if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
if constexpr(BQuantGroupSize::kN >= (NWarp * WG::kN))
|
||||
{
|
||||
return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ +
|
||||
return (nIter * NWarp * WG::kN) / BQuantGroupSize::kN *
|
||||
KPerBlockBQ +
|
||||
kQScale;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -33,6 +33,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
@@ -75,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
@@ -134,8 +136,12 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using OverrideBDataType = std::conditional_t<
|
||||
std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using Base = BlockGemmQuantBase;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -156,7 +162,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
|
||||
|
||||
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
|
||||
|
||||
@@ -354,11 +361,24 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
|
||||
aq_block_tensor);
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::BQuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN) &&
|
||||
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
|
||||
{
|
||||
return kQScale;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter;
|
||||
}
|
||||
}();
|
||||
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
@@ -34,7 +34,7 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
@@ -43,7 +43,7 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -69,20 +69,20 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(KPerBlock, AQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WarpGemm::kK, AQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(AQuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of AQuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! AQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
static_assert(KPerBlock / AQuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
@@ -110,8 +110,8 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool TransposeC = Problem::TransposeC;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool TransposeC = Problem::TransposeC;
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
@@ -36,7 +36,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
@@ -46,8 +46,8 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN;
|
||||
static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -72,23 +72,23 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerWarpGemmRow =
|
||||
integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
|
||||
|
||||
static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of QuantGroupSize");
|
||||
static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0,
|
||||
"Error! WarpGemm::kK should be a multiple of BQuantGroupSize");
|
||||
static_assert(QScalesPerWarpGemmRow == 1,
|
||||
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
"Error! BQuantGroupSize shouldn't be smaller than WarpGemm::kK");
|
||||
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
|
||||
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
|
||||
|
||||
static_assert(KPerBlock / QuantGroupSize::kK > 0,
|
||||
static_assert(KPerBlock / BQuantGroupSize::kK > 0,
|
||||
"Error! Each row of blockgemm should have a separate scale");
|
||||
|
||||
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
|
||||
@@ -153,7 +153,7 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
|
||||
|
||||
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
|
||||
|
||||
@@ -317,25 +317,21 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
// constexpr index_t reg_offset = nIter;
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN))
|
||||
if constexpr(GemmTraits::BQuantGroupSize::kN >
|
||||
(NWarp * WarpGemm::kN) &&
|
||||
Traits::NPerBlock == GemmTraits::BQuantGroupSize::kN)
|
||||
{
|
||||
if constexpr(Traits::NPerBlock ==
|
||||
GemmTraits::QuantGroupSize::kN)
|
||||
return kQScale;
|
||||
else
|
||||
return nIter; // for prefill needs kQscale, for decode needs
|
||||
// nIter
|
||||
return kQScale; // prefill: one quant group per block
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter;
|
||||
return nIter; // decode or multiple groups per warp
|
||||
}
|
||||
}();
|
||||
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
@@ -370,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
{
|
||||
// Multiply bquant with accumulated C
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >=
|
||||
if constexpr(GemmTraits::BQuantGroupSize::kN >=
|
||||
(NWarp * WarpGemm::kN))
|
||||
return (nIter * NWarp * WarpGemm::kN) /
|
||||
GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
|
||||
GemmTraits::BQuantGroupSize::kN *
|
||||
Traits::KQPerBlock +
|
||||
kQScale;
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user