mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)
* initial commit * preshuffleQuant support for ABQuant * fix mxfp4 to use correct QuantGroupSize * addressing review comments and seperated Preshufflequant for A and B * updated grouped gemm example for updated traits definition * fix for CI failure * updated grouped_gemm_abquant test for updated traits definition * updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
@@ -59,7 +59,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -202,7 +203,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -44,7 +44,8 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -210,7 +211,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // APreshuffleQuant
|
||||
false, // BPreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -134,5 +134,35 @@ static auto _ = []() {
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"abquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfigPreshuffleBQuantPrefill<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
AQuantGroupSize,
|
||||
BQuantGroupSize,
|
||||
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
|
||||
};
|
||||
return 0;
|
||||
}();
|
||||
|
||||
@@ -80,7 +80,8 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool APreshuffleQuant = false;
|
||||
static constexpr bool BPreshuffleQuant = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
@@ -157,7 +158,8 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool APreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -187,7 +189,7 @@ template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Decode
|
||||
: public GemmConfigPreshuffleB_BQuant_Decode<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -218,7 +220,7 @@ template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
: public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -272,7 +274,7 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -33,7 +33,8 @@ template <typename GemmConfig,
|
||||
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped;
|
||||
constexpr bool transpose_c =
|
||||
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
|
||||
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant,
|
||||
typename TypeConfig::BDataType,
|
||||
@@ -50,14 +51,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
using GemmTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
GemmConfig::APreshuffleQuant,
|
||||
GemmConfig::BPreshuffleQuant,
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout, // for AQLayout
|
||||
BQLayout, // for BQLayout
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
transpose_c,
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
|
||||
@@ -73,7 +75,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
@@ -146,7 +148,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
using AQuantPipeline =
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant,
|
||||
std::conditional_t<GemmConfig::APreshuffleQuant,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>;
|
||||
|
||||
@@ -390,8 +392,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
|
||||
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
|
||||
<< " QuantMode = " << quant_type_to_string(QuantMode)
|
||||
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
|
||||
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
<< " APreshuffleQuant = " << (GemmConfig::APreshuffleQuant ? "true" : "false")
|
||||
<< " : "
|
||||
<< " BPreshuffleQuant = " << (GemmConfig::BPreshuffleQuant ? "true" : "false")
|
||||
<< " : " << " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
@@ -536,21 +540,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
// Create BQ tensor with appropriate shape
|
||||
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
|
||||
}
|
||||
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
@@ -870,7 +866,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant)
|
||||
{
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
if constexpr(GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / AQuantGroupSize::kK);
|
||||
@@ -929,7 +925,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
ck_tile::HostTensor<BQDataType> bq_permuted_host =
|
||||
ck_tile::bq_permuteN<GemmConfig>(*bq_tensor_ptr, BQuantGroupSize::kN);
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host = ck_tile::shuffle_bq(
|
||||
&bq_permuted_host, GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
@@ -940,7 +936,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
bq_dev_buf_ptr->ToDevice(bq_permuted_host.data());
|
||||
}
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
else if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<BQDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
@@ -1121,7 +1117,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
|
||||
!GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
@@ -1142,7 +1138,8 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser, Col{}, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant)
|
||||
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
|
||||
!GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user