[CK_TILE] Add Bquant to Grouped Gemm (#3063)

* update test cases

* format codes

* use GTEST_FAIL

* add bquant to grouped_gemm

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* fix a bug in test_grouped_gemm_util

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* chore: clang formatting

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-10-28 10:20:24 -04:00
committed by GitHub
parent 1c17bae816
commit 4368fd9f57
8 changed files with 276 additions and 104 deletions

View File

@@ -29,7 +29,7 @@ template <typename GemmConfig,
typename BQDataType,
typename AccDataType,
typename CDataType,
ck_tile::QuantType QuantMode>
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr)
@@ -48,8 +48,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false,
false,
false, // PreshuffleQuant
false, // PreshuffleB
ALayout,
BLayout,
CLayout,
@@ -67,18 +67,29 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using QuantGemmProblem = typename std::conditional<
QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
128>, // QuantGroupSize
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>>::type;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
using GemmPipeline =
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,

View File

@@ -11,6 +11,7 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -41,6 +42,14 @@ struct GemmTypeConfig<ck_tile::fp8_t>
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
struct GemmConfigBase
{
@@ -77,24 +86,11 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int kBlockPerCu = 1;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
@@ -122,8 +118,7 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("quant_mode", "tensor", "Choose tensor (default), or rowcol");
;
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -43,8 +43,8 @@ template <typename GemmConfig,
typename BLayout,
typename BQLayout,
typename CLayout,
ck_tile::QuantType QuantMode,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
@@ -159,11 +159,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
return group_count != 0 && ((args.size() == static_cast<size_t>(group_count)) && ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
bool validate = arg_parser.get_bool("validate");
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
bool validate = arg_parser.get_bool("validate");
const ck_tile::index_t QuantGroupSize = 128;
if(kbatch > 1 && validate && warmup + repeat > 1)
{
@@ -172,9 +173,11 @@ int run_grouped_gemm_example_with_layouts(int argc,
validate = false;
}
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> AQs; // dimension of AQ tensor is calculated from A tensor
std::vector<ck_tile::index_t> BQs; // dimension of BQ tensor is calculated from B tensor
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
@@ -252,6 +255,15 @@ int run_grouped_gemm_example_with_layouts(int argc,
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
}
}
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
@@ -289,6 +301,13 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
}
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
@@ -394,6 +413,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors[i],
c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
@@ -441,42 +471,6 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
QuantMode>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
QuantMode>(
argc, argv, Row{}, Row{}, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
QuantMode>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType,
QuantMode>(
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
@@ -513,6 +507,41 @@ int run_grouped_gemm_example(int argc, char* argv[])
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported quantization mode!");
}
}
if(data_type == "bf8")
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported quantization mode!");