refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic

This commit is contained in:
Erwin Terpstra
2025-12-01 14:27:16 +00:00
parent 0cb77e511b
commit 9a01e4ac28
2 changed files with 41 additions and 32 deletions

View File

@@ -40,9 +40,6 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
@@ -74,16 +71,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
@@ -100,6 +90,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
@@ -138,15 +130,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
has_hot_loop_v,
tail_number_v>>;
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -288,15 +274,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BDataType,
scheduler>>;
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,

View File

@@ -83,7 +83,7 @@ struct GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
static constexpr bool Persistent = true;
static constexpr bool Persistent = false;
};
template <typename PrecType>
@@ -149,6 +149,12 @@ struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
{
template <typename PrecType>
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
@@ -156,6 +162,12 @@ struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
{
template <typename PrecType>
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
@@ -163,6 +175,12 @@ struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
{
template <typename PrecType>
using GemmConfig = GemmConfig_Aquant<PrecType>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
@@ -170,6 +188,17 @@ struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
{
template <typename PrecType>
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<GemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline =
std::conditional_t<PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;