mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)
* initial commit * preshuffleQuant support for ABQuant * fix mxfp4 to use correct QuantGroupSize * addressing review comments and seperated Preshufflequant for A and B * updated grouped gemm example for updated traits definition * fix for CI failure * updated grouped_gemm_abquant test for updated traits definition * updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
@@ -72,7 +72,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(BQuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
|
||||
: 1;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
@@ -95,7 +98,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -264,7 +268,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
static_assert(
|
||||
PreshuffleQuant ||
|
||||
BPreshuffleQuant ||
|
||||
(is_bq_row_major
|
||||
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
@@ -323,15 +327,18 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
// only row_major for AQ
|
||||
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
PreshuffleQuant
|
||||
APreshuffleQuant
|
||||
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
: (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ));
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
0)
|
||||
(BPreshuffleQuant)
|
||||
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0)
|
||||
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
|
||||
@@ -484,7 +491,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
|
||||
currIdx = (currIdx + 1) % 2;
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -495,7 +502,7 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
// Note: BDataType gets converted during loading from PkInt4
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
|
||||
|
||||
@@ -12,21 +12,21 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockAQ = KPerBlock / AQuantGroupSize::kK;
|
||||
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
static_assert(KPerBlock % AQuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize");
|
||||
|
||||
// Create DRAM tile window for AQ
|
||||
|
||||
@@ -23,19 +23,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
|
||||
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
|
||||
using OverrideADataType =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -60,7 +60,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -78,7 +78,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -99,7 +99,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(),
|
||||
Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
|
||||
// clang-format on
|
||||
}
|
||||
@@ -156,7 +156,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -216,7 +216,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
|
||||
static_assert(!APreshuffleQuant, "Memory pipeline does not support APreshuffleQuant!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
|
||||
@@ -32,22 +32,22 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using AQLayout = remove_cvref_t<typename Problem::AQLayout>;
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(APreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
|
||||
BlockGemmShape,
|
||||
@@ -57,7 +57,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
APreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
@@ -89,7 +89,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
APreshuffleQuant>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
@@ -103,7 +103,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
MPerBlock, // XPerTile
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
PreshuffleQuant>;
|
||||
APreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution_transposed();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,19 +20,19 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using AQuantGroupSize = remove_cvref_t<typename Problem::AQuantGroupSize>;
|
||||
// When ADataType is pk_int4_t, use BDataType instead for transpose operations
|
||||
// since packed 4-bit integers cannot be directly transposed (requires at least 8-bit precision)
|
||||
using OverrideADataType =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
static_assert(AQuantGroupSize::kM == 1, "no block for M supported yet!");
|
||||
static_assert(AQuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -57,7 +57,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / AQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -75,7 +75,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -96,7 +96,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "AQuantGroupSize: " << AQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -271,7 +271,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
|
||||
// only row_major for AQ
|
||||
const AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
PreshuffleQuant
|
||||
APreshuffleQuant
|
||||
? make_array(ck_tile::integer_least_multiple(m, MPerBlock) /
|
||||
BlockGemm::WarpGemm::kM,
|
||||
0)
|
||||
|
||||
@@ -12,13 +12,13 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -27,16 +27,16 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(QuantGroupSize::kN <= NPerBlock) ? NPerBlock / QuantGroupSize::kN : 1;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
(BQuantGroupSize::kN <= NPerBlock) ? NPerBlock / BQuantGroupSize::kN : 1;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
// static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
|
||||
|
||||
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
// "NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
// static_assert(NPerBlock % BQuantGroupSize::kN == 0,
|
||||
// "NPerBlock must be a multiple of BQuantGroupSize::kN");
|
||||
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of BQuantGroupSize::kK");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
|
||||
@@ -43,14 +43,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = (Problem::QuantGroupSize::kN <= NPerBlock)
|
||||
? NPerBlock / Problem::QuantGroupSize::kN
|
||||
: 1;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = (Problem::BQuantGroupSize::kN <= NPerBlock)
|
||||
? NPerBlock / Problem::BQuantGroupSize::kN
|
||||
: 1;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
@@ -61,7 +61,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
|
||||
BlockGemmShape,
|
||||
@@ -72,7 +72,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
Problem::BQuantGroupSize::kN,
|
||||
Problem::BQuantGroupSize::kK,
|
||||
BQLayout,
|
||||
PreshuffleQuant>;
|
||||
BPreshuffleQuant>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
|
||||
@@ -26,12 +26,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
@@ -45,7 +45,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
@@ -66,11 +66,11 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(QuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN)
|
||||
(BQuantGroupSize::kN <= BlockGemmShape::kN)
|
||||
? integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN)
|
||||
: 1;
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -88,7 +88,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -109,7 +109,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -252,7 +252,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
static_assert(
|
||||
PreshuffleQuant ||
|
||||
BPreshuffleQuant ||
|
||||
(is_bq_row_major
|
||||
? (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
@@ -304,9 +304,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
const BQDramTileWindowStep bq_dram_tile_window_step =
|
||||
(PreshuffleQuant)
|
||||
(BPreshuffleQuant)
|
||||
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0)
|
||||
|
||||
@@ -52,7 +52,7 @@ template <typename BlockGemmShape,
|
||||
index_t XPerTile,
|
||||
index_t KPerBlockAQ,
|
||||
index_t VecSize,
|
||||
bool PreshuffleQuant>
|
||||
bool APreshuffleQuant>
|
||||
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -72,7 +72,7 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(APreshuffleQuant)
|
||||
{
|
||||
// # of elements per thread
|
||||
static_assert(XPerTile >= warp_size && XPerTile % warp_size == 0);
|
||||
@@ -193,8 +193,8 @@ template <typename BlockGemmShape,
|
||||
index_t NPerTile,
|
||||
index_t NPerQ,
|
||||
index_t KPerQ,
|
||||
typename BQLayout = tensor_layout::gemm::ColumnMajor,
|
||||
bool PreshuffleQuant = false>
|
||||
typename BQLayout = tensor_layout::gemm::ColumnMajor,
|
||||
bool BPreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
@@ -212,10 +212,11 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
|
||||
{
|
||||
// Preshuffle only supported for ColumnMajor currently
|
||||
static_assert(!(PreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
|
||||
"PreshuffleQuant only supported for ColumnMajor BQLayout");
|
||||
static_assert(
|
||||
!(BPreshuffleQuant && std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>),
|
||||
"PreshuffleQuant only supported for ColumnMajor BQLayout");
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
// =============================================================================
|
||||
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
|
||||
|
||||
@@ -12,13 +12,13 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
@@ -26,16 +26,16 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Probl
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / BQuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / BQuantGroupSize::kK;
|
||||
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= BQuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= BQuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
static_assert(NPerBlock % BQuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of BQuantGroupSize::kN");
|
||||
static_assert(KPerBlock % BQuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of BQuantGroupSize::kK");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
|
||||
@@ -22,9 +22,9 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
@@ -76,7 +76,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2
|
||||
constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
@@ -109,7 +109,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
|
||||
@@ -24,15 +24,15 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -58,8 +58,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / BQuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / BQuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -93,7 +93,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -149,7 +149,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "BQuantGroupSize: " << BQuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
@@ -412,7 +412,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2);
|
||||
|
||||
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK;
|
||||
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / BQuantGroupSize::kK;
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename QuantGroupSize_,
|
||||
typename AQuantGroupSize_,
|
||||
bool TransposeC_,
|
||||
typename ComputeDataType_ = BDataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
@@ -133,7 +133,7 @@ using GemmAQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
QuantGroupSize_,
|
||||
AQuantGroupSize_,
|
||||
void,
|
||||
TransposeC_,
|
||||
ComputeDataType_,
|
||||
@@ -147,7 +147,7 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename QuantGroupSize_,
|
||||
typename BQuantGroupSize_,
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
@@ -160,7 +160,7 @@ using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
void,
|
||||
QuantGroupSize_,
|
||||
BQuantGroupSize_,
|
||||
false, // no TransposeC
|
||||
ComputeDataType_,
|
||||
Scheduler_,
|
||||
|
||||
@@ -25,7 +25,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
@@ -69,14 +69,14 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
using Base::m_preload;
|
||||
|
||||
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
|
||||
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
|
||||
integer_divide_ceil(BlockGemmShape::kN, BQuantGroupSize::kN);
|
||||
static constexpr index_t KPerBlockBQ =
|
||||
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
|
||||
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);
|
||||
static constexpr index_t QScalesPerBlockRow =
|
||||
integer_divide_ceil(kKPerBlock, QuantGroupSize::kK);
|
||||
integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);
|
||||
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
@@ -94,7 +94,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeBQ()),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
concat('x', kPadM, kPadN, kPadK), BQuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
// then by vector width to get an approximate number of vector loads.
|
||||
constexpr index_t BQload_inst = ck_tile::integer_divide_ceil(
|
||||
ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType),
|
||||
QuantGroupSize::kK * QuantGroupSize::kK),
|
||||
BQuantGroupSize::kK * BQuantGroupSize::kK),
|
||||
VectorLoadSize);
|
||||
|
||||
// ToDo: Hardcoded, need to change in future. How many instruction emit per iteration
|
||||
@@ -360,11 +360,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
BQBlockTile bq_block_tile, bq_block_tile_2;
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
// move BQ to tile 1
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
@@ -437,11 +437,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
@@ -474,11 +474,11 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
if constexpr(PreshuffleQuant)
|
||||
if constexpr(BPreshuffleQuant)
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window,
|
||||
{((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
|
||||
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
|
||||
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
|
||||
: ck_tile::integer_least_multiple(n, kNPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{})),
|
||||
0});
|
||||
|
||||
@@ -33,7 +33,8 @@ inline std::string quant_type_to_string(QuantType quant_type)
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool PreshuffleQuant_,
|
||||
bool APreshuffleQuant_,
|
||||
bool BPreshuffleQuant_,
|
||||
bool PreshuffleB_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
@@ -71,8 +72,9 @@ struct TileGemmQuantTraits
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
static constexpr bool PreshuffleB = PreshuffleB_;
|
||||
static constexpr bool APreshuffleQuant = APreshuffleQuant_;
|
||||
static constexpr bool BPreshuffleQuant = BPreshuffleQuant_;
|
||||
static constexpr bool PreshuffleB = PreshuffleB_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user