diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 7becd28f77..2162141156 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -61,12 +61,16 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmTraits, ComputeDataType>; - // This example only supports BQuant (no AQuant) - // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 + // Base pipeline selection based on quant mode and preshuffle settings using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, + ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -125,7 +129,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>, std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 9dea74c425..ca8598a03f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -256,6 +256,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem; static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); + static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], "Aq block window has incorrect lengths for defined AqLayout!");