test(grouped_gemm): add unit tests for grouped_gemm bquant with preshuffleB true (#3119)

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* add bquant to grouped_gemm

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* change code based on comments

* WIP: bquant preshuffle b compiles but gives numerical error

* feat(grouped_gemm_quant): bquant with preshuffleB support added to grouped_gemm example & kernel

* refactor: refactor code after merge commit

* chore: remove print statements

* test(grouped_gemm): split test cases by quant mode to reduce compilation time and add bquant-preshuffleB mode test cases

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-10-31 15:07:06 -04:00
committed by GitHub
parent a33d98f8e2
commit 8f1274d9b6
14 changed files with 425 additions and 74 deletions

View File

@@ -49,7 +49,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // PreshuffleB
GemmConfig::PreshuffleB, // PreshuffleB
ALayout,
BLayout,
CLayout,
@@ -58,7 +58,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
true>;
true>; // Persistence
float ave_time{0};
@@ -86,10 +86,14 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BDataType,
scheduler>>::type;
using GemmPipeline =
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -141,5 +145,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
return result1;
}

View File

@@ -10,9 +10,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -31,6 +28,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -67,8 +80,9 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
};
template <typename PrecType>
@@ -85,10 +99,26 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
static constexpr bool DoubleSmemBuffer = false;
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
@@ -118,7 +148,8 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -163,6 +163,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
const int init_method = arg_parser.get_int("init");
bool validate = arg_parser.get_bool("validate");
const ck_tile::index_t QuantGroupSize = 128;
@@ -203,6 +204,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
@@ -280,6 +282,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
stride_AQs[i] = 0; // No A quantization
stride_BQs[i] =
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
}
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
@@ -313,10 +321,20 @@ int run_grouped_gemm_example_with_layouts(int argc,
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
}
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
@@ -329,8 +347,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::HostTensor<BDataType> b_shuffle_host =
ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
}
else
{
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
}
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();