mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] add preshuffleB mode for ABQuant GEMM (#3495)
* [CK_TILE] add preshuffleB mode for ABQuant GEMM * fix precommit error * use template method call for cvt_scale_to_fp32 * fix precommit error * add test code * fix precommit error * switch abquant gemmconfig to default * Add changelog.md * fix precommit error * fix conflict
This commit is contained in:
@@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory(
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"bf8",
|
||||
"abquant",
|
||||
"preshuffleb",
|
||||
"non-preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
using AQuantPipeline =
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
|
||||
|
||||
using BQuantPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using ABQuantPipeline =
|
||||
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>>>;
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
AQuantPipeline,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ABQuantPipeline,
|
||||
BQuantPipeline>>>;
|
||||
|
||||
constexpr bool TiledPermuteN =
|
||||
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
|
||||
@@ -908,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if((QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
|
||||
GemmConfig::PreshuffleB)
|
||||
@@ -938,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::PreshuffleQuant)
|
||||
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user