mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
GEMM Blockscale ABQuant Optimization (#3620)
* GEMM Blockscale ABQuant Optimization * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix precommit error * clean * Fix --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Ding, Yi <yi.ding@amd.com>
This commit is contained in:
@@ -4,7 +4,13 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigABQuantPrefill<T>;
|
||||
|
||||
// Alias template selecting the preshuffle-B GEMM configuration for element
// type T. It currently resolves to GemmConfigPreshuffleB_ABQuant_Prefill<T>.
template <typename T>
|
||||
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
|
||||
|
||||
// template <typename T>
|
||||
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
|
||||
|
||||
void abquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
@@ -78,7 +84,7 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -93,7 +99,7 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::fp8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -108,7 +114,7 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
@@ -123,7 +129,7 @@ void abquant_quantgrouped_instance_factory(
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB_BQuant_Prefill<ck_tile::bf8_t>,
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
|
||||
@@ -192,6 +192,28 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
// ABQuant "prefill" configuration for the preshuffle-B path.
//
// Derives from GemmConfigPreshuffleB_BQuant_Prefill<PrecType> and declares its
// own warp layout, K-padding flag and TransposeC flag. Each member declared
// here hides any member of the same name inherited from the base; everything
// else is inherited unchanged.
//
// PrecType is only forwarded to the base class in this struct.
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
// Warp counts along M, N and K: 2 x 2 x 1.
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||

||||
// K padding is disabled for this configuration.
// NOTE(review): presumably callers must supply K as a multiple of the K tile; confirm.
static constexpr bool kPadK = false;
|
||||
// Requests the transposed-C variant.
// NOTE(review): where this flag is consumed is not visible here; confirm it is read.
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
// ABQuant "decode" configuration for the preshuffle-B path.
//
// Declares its own block tile sizes, K-padding flag and TransposeC flag; each
// hides any same-named member inherited from the base.
//
// NOTE(review): despite the "Decode" name this derives from
// GemmConfigPreshuffleB_BQuant_Prefill, so every setting not redeclared below
// comes from the prefill base. Confirm this is intentional.
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
// Block tile extents along M and N.
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
// K tile is 256 divided by the element size in bytes (e.g. 256 elements
// for a 1-byte PrecType, 128 for a 2-byte PrecType).
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||

||||
// K padding is disabled for this configuration.
static constexpr bool kPadK = false;
|
||||
// Requests the transposed-C variant.
// NOTE(review): where this flag is consumed is not visible here; confirm it is read.
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
{
|
||||
@@ -209,6 +231,13 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
// ABQuant "prefill" configuration for the non-preshuffled path.
//
// Inherits every setting from GemmConfigQuantPrefill<PrecType> and declares
// only kPadK and TransposeC, which hide any same-named members of the base.
template <typename PrecType>
|
||||
struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
// K padding is disabled for this configuration.
static constexpr bool kPadK = false;
|
||||
// Requests the transposed-C variant.
// NOTE(review): where this flag is consumed is not visible here; confirm it is read.
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
|
||||
@@ -33,6 +33,7 @@ template <typename GemmConfig,
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
|
||||
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
typename TypeConfig::BDataType,
|
||||
@@ -57,7 +58,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
QuantMode,
|
||||
AQLayout, // for AQLayout
|
||||
BQLayout, // for BQLayout
|
||||
false,
|
||||
transpose_c,
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
|
||||
@@ -88,7 +89,6 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
// row-col and tensor quants use the regular pipeline, A/B/AB quants use their own
|
||||
using PipelineProblem = std::conditional_t<
|
||||
|
||||
Reference in New Issue
Block a user