From 9a01e4ac28bb7e8ebeacf4c0c43c0fdedc65a541 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Mon, 1 Dec 2025 14:27:16 +0000 Subject: [PATCH] refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 42 +++++-------------- .../17_grouped_gemm/quant_grouped_gemm.hpp | 31 +++++++++++++- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index fe36d2c2ce..d3b75ac72f 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -40,9 +40,6 @@ float grouped_gemm(const std::vector& gemm_descs, constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; - constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped; - using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, @@ -74,16 +71,9 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = std::conditional_t< - UseGroupedQuant, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3, - std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>>, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; @@ -100,6 +90,8 @@ float grouped_gemm(const std::vector& gemm_descs, constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; using QuantGemmProblem = std::conditional_t< UseGroupedQuant, std::conditional_t& gemm_descs, has_hot_loop_v, tail_number_v>>; - using GemmPipeline = std::conditional_t< - UseGroupedQuant, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, - ck_tile::GemmPipelineAgBgCrCompV3>; + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; - using GemmPipeline = std::conditional_t< - UseGroupedQuant, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, - ck_tile::GemmPipelineAgBgCrCompV3>; + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem @@ -149,6 +149,12 @@ struct GemmQuantConfig { template using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; }; template <> @@ -156,6 +162,12 @@ struct GemmQuantConfig { template using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; }; template <> @@ -163,6 +175,12 @@ struct GemmQuantConfig { template using GemmConfig = GemmConfig_Aquant; + + template + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; }; template <> @@ -170,6 +188,17 @@ struct GemmQuantConfig { template using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + + template + using GemmPipeline = std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>; + + template + using BaseGemmPipeline = + std::conditional_t, + ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; }; using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;