GEMM Blockscale ABQuant Optimization (#3620)

* GEMM Blockscale ABQuant Optimization

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix precommit error

* clean

* Fix

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
This commit is contained in:
kensclin
2026-01-23 01:39:38 +08:00
committed by GitHub
parent 9e049a32a1
commit 31a35ecab4
7 changed files with 161 additions and 51 deletions

View File

@@ -4,7 +4,13 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
using GemmConfig = GemmConfigABQuantPrefill<T>;
template <typename T>
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
// template <typename T>
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
void abquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
@@ -78,7 +84,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -93,7 +99,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -108,7 +114,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
@@ -123,7 +129,7 @@ void abquant_quantgrouped_instance_factory(
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,

View File

@@ -192,6 +192,28 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
static constexpr bool PreshuffleQuant = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
};
template <typename PrecType>
struct GemmConfigQuantPrefill : public GemmConfigBase
{
@@ -209,6 +231,13 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{

View File

@@ -33,6 +33,7 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
@@ -57,7 +58,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
QuantMode,
AQLayout, // for AQLayout
BQLayout, // for BQLayout
false,
transpose_c,
GemmConfig::DoubleSmemBuffer>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
@@ -88,7 +89,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
// row-col and tensor quants use the regular pipeline, A/B/AB quants use their own
using PipelineProblem = std::conditional_t<