mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
fix: add separate GemmConfig structs for AQuant, automatically select the correct one
This commit is contained in:
@@ -67,7 +67,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
@@ -82,7 +81,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
transpose_c>,
|
||||
GemmConfig::TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
@@ -96,7 +95,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler>>;
|
||||
|
||||
@@ -161,7 +160,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv); /* ||
|
||||
run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);*/
|
||||
int result1 = run_grouped_gemm_example(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
|
||||
@@ -102,6 +102,24 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfig_Aquant : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
{
|
||||
@@ -118,10 +136,42 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <ck_tile::QuantType QuantMode>
|
||||
struct GemmQuantConfig;
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfig_Aquant<PrecType>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
|
||||
{
|
||||
template <typename PrecType>
|
||||
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -533,12 +533,13 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
|
||||
template <typename PrecType, ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType>;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
@@ -567,7 +568,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -585,30 +585,22 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
@@ -620,30 +612,22 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::AQuantGrouped>(
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user