mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic
This commit is contained in:
@@ -40,9 +40,6 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
@@ -74,16 +71,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
@@ -100,6 +90,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
@@ -138,15 +130,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -288,15 +274,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BDataType,
|
||||
scheduler>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
|
||||
@@ -83,7 +83,7 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool Persistent = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -149,6 +149,12 @@ struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -156,6 +162,12 @@ struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -163,6 +175,12 @@ struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfig_Aquant<PrecType>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -170,6 +188,17 @@ struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<GemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline =
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmProblem>,
|
||||
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
Reference in New Issue
Block a user