fix: add separate GemmConfig structs for AQuant, automatically select the correct one

This commit is contained in:
Erwin Terpstra
2025-11-28 15:06:10 +00:00
parent 5d4a91a09b
commit f7409227fb
3 changed files with 66 additions and 34 deletions

View File

@@ -67,7 +67,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
@@ -82,7 +81,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
transpose_c>,
GemmConfig::TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
@@ -96,7 +95,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
GemmConfig::TransposeC,
BDataType,
scheduler>>;
@@ -161,7 +160,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
/// Program entry point for the grouped GEMM quant example.
///
/// Forwards the command line untouched to run_grouped_gemm_example and uses
/// its return value as the process exit status.
///
/// @param argc number of command-line arguments
/// @param argv command-line argument vector
/// @return whatever run_grouped_gemm_example returns
int main(int argc, char* argv[])
{
    // No GemmConfig template argument is supplied here: the configuration is
    // looked up from the quant mode via GemmQuantConfig further down the call
    // chain. Keeping the older templated call next to this one would declare
    // `result1` twice in the same scope.
    return run_grouped_gemm_example(argc, argv);
}

View File

@@ -102,6 +102,24 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
/// Tile configuration selected for the AQuantGrouped quant mode
/// (see GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>).
///
/// Everything not listed here is inherited from GemmConfigBase.
///
/// @tparam PrecType element type; its size drives K_Tile and K_Warp_Tile.
template <typename PrecType>
struct GemmConfig_Aquant : GemmConfigBase
{
    // Tile extents along M, N and K. K_Tile is expressed as 128 bytes' worth
    // of PrecType elements (128 for 1-byte types).
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile =
        static_cast<ck_tile::index_t>(128 / sizeof(PrecType));

    // Warp counts along M, N and K.
    // NOTE(review): presumably the warp layout within a block tile - confirm.
    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-warp tile extents; the K extent is derived from the precision type
    // and M_Warp_Tile by get_k_warp_tile.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

    // This configuration sets TransposeC to true; the pipeline problem types
    // read it as GemmConfig::TransposeC.
    static constexpr bool TransposeC = true;
};
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
{
@@ -118,10 +136,42 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool TransposeC = false;
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};
/// Compile-time map from a ck_tile::QuantType to the GemmConfig template that
/// the example uses for that quant mode.
///
/// Usage: `typename GemmQuantConfig<QuantMode>::template GemmConfig<PrecType>`.
///
/// The primary template is only declared. A QuantType without a
/// specialization below has no nested GemmConfig, so using it fails to
/// compile.
template <ck_tile::QuantType QuantMode>
struct GemmQuantConfig;

/// TensorQuant uses GemmConfigComputeV3_2.
template <>
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
{
    template <typename PrecType>
    using GemmConfig = GemmConfigComputeV3_2<PrecType>;
};

/// RowColQuant uses the same configuration as TensorQuant.
template <>
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
{
    template <typename PrecType>
    using GemmConfig = GemmConfigComputeV3_2<PrecType>;
};

/// AQuantGrouped uses its dedicated GemmConfig_Aquant.
template <>
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
{
    template <typename PrecType>
    using GemmConfig = GemmConfig_Aquant<PrecType>;
};

/// BQuantGrouped uses GemmConfigPreshuffleB_Bquant_prefill.
template <>
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
{
    template <typename PrecType>
    using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])

View File

@@ -533,12 +533,13 @@ int run_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
template <typename PrecType, ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType>;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
@@ -567,7 +568,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -585,30 +585,22 @@ int run_grouped_gemm_example(int argc, char* argv[])
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::TensorQuant>(
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::RowColQuant>(
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::AQuantGrouped>(
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::BQuantGrouped>(
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
@@ -620,30 +612,22 @@ int run_grouped_gemm_example(int argc, char* argv[])
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::TensorQuant>(
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::RowColQuant>(
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::AQuantGrouped>(
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::BQuantGrouped>(
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else