rename quant pipeline problem

This commit is contained in:
Sami Remes
2025-09-15 13:47:01 +00:00
parent fc4dbd8b7b
commit dc97be711d
3 changed files with 34 additions and 33 deletions

View File

@@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;

View File

@@ -70,17 +70,17 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using PipelineProblem = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
transpose_c,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
typename TypeConfig::AccDataType,
typename TypeConfig::AccDataType,
GemmShape,
GemmTraits,
transpose_c,
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
tail_number_v>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType,

View File

@@ -168,17 +168,18 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
using GemmRowColTensorQuantPipelineProblem =
GemmQuantPipelineProblemBase<ADataType_,
AccDataType_,
BDataType_,
AccDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
1, // no group size applicable
TransposeC_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
} // namespace ck_tile