fix: add compile-time selection of pipelines for AQuant mode (#3282)

- Add conditional selection for AQuant: use the CompV3 pipeline when PreshuffleQuant is true, and the memory (Mem) pipeline otherwise
- Add static assertion in memory pipeline to prevent PreshuffleQuant usage
- Restore BaseBQuantGemmPipelineAgBgCrCompV3 for BQuant cases
- Update BaseGemmPipeline selection to handle all quant modes properly
This commit is contained in:
Aviral Goel
2025-11-26 10:58:09 +04:00
committed by GitHub
parent 8fa90025d0
commit 35a4b26af0
2 changed files with 11 additions and 4 deletions

View File

@@ -61,12 +61,16 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmTraits,
ComputeDataType>;
// This example only supports BQuant (no AQuant)
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
// Base pipeline selection based on quant mode and preshuffle settings
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -125,7 +129,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<GemmConfig::PreshuffleQuant == true,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;