test(grouped_gemm): add unit tests for grouped_gemm bquant with preshuffleB true (#3119)

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* add bquant to grouped_gemm

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* change code based on comments

* WIP: bquant preshuffle b compiles but gives numerical error

* feat(grouped_gemm_quant): bquant with preshuffleB support added to grouped_gemm example & kernel

* refactor: refactor code after merge commit

* chore: remove print statements

* test(grouped_gemm): split test cases by quant mode to reduce compilation time and add bquant-preshuffleB mode test cases

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-10-31 15:07:06 -04:00
committed by GitHub
parent a33d98f8e2
commit 8f1274d9b6
14 changed files with 425 additions and 74 deletions

View File

@@ -483,6 +483,7 @@ struct QuantGemmKernel
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -790,6 +791,7 @@ struct QuantGemmKernel
}();
if constexpr(PreshuffleB)
{
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
}
else
@@ -802,6 +804,7 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
@@ -867,6 +870,7 @@ struct QuantGemmKernel
const auto& b_block_window = [&]() {
if constexpr(PreshuffleB)
{
return make_tile_window(
b_pad_view,
make_tuple(number<GemmPipeline::flatNPerWarp>{},

View File

@@ -317,13 +317,88 @@ struct QuantGroupedGemmKernel
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
static_assert(GemmPipeline::DoubleSmemBuffer == false,
"DoubleSmemBuffer needs to be false");
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
RunGemmWithPipelineSelection(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
// Only for BQuantGrouped DoubleSmemBuffer is supported
if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
kQuantType == QuantType::BQuantGrouped)
{
__shared__ char smem_ptr_1[GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
void* smem_ptr_1,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
/**

View File

@@ -458,6 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()<TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
@@ -467,5 +468,31 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
p_smem_ping,
p_smem_pong);
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename BQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* p_smem_ping,
void* p_smem_pong) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
return operator()<tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
bq_dram_block_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}
};
} // namespace ck_tile