Merge commit '35a4b26af0088ca0d634b57055a4143fdb9f2e2d' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-26 07:13:26 +00:00
parent a86762f0f9
commit 283383c61c
2 changed files with 11 additions and 4 deletions

View File

@@ -61,12 +61,16 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmTraits,
ComputeDataType>;
// This example only supports BQuant (no AQuant)
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
// Base pipeline selection based on quant mode and preshuffle settings
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -125,7 +129,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<GemmConfig::PreshuffleQuant == true,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;

View File

@@ -256,6 +256,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Prob
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!");
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Aq block window has incorrect lengths for defined AqLayout!");