[CK_TILE] Add Bquant to Grouped Gemm (#3063)

* update test cases

* format codes

* use GTEST_FAIL

* add bquant to grouped_gemm

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* fix a bug in test_grouped_gemm_util

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* chore: clang formatting

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-10-28 10:20:24 -04:00
committed by GitHub
parent 1c17bae816
commit 4368fd9f57
8 changed files with 276 additions and 104 deletions

View File

@@ -375,30 +375,48 @@ struct QuantGroupedGemmKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
if constexpr(kQuantType == QuantType::RowColQuant)
if constexpr(kQuantType == QuantType::BQuantGrouped)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
auto& c_block_window = gemm_tile_windows.at(Base::I4);
// Run Epilogue Pipeline
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
else
{
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
if constexpr(kQuantType == QuantType::RowColQuant)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
else if constexpr(kQuantType == QuantType::TensorQuant)
{
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
}
}
}

View File

@@ -472,6 +472,49 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
num_loop,
p_smem);
}
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
///
/// This operator is used by grouped GEMM kernels where pipeline parameters
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
/// at runtime, not on the host side during compilation. This is necessary
/// because different GEMM problems in the group may have different K dimensions,
/// requiring different pipeline configurations that cannot be determined at
/// compile time.
///
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
/// @param num_loop Number of main loop iterations (calculated on device)
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
/// @param tail_number Type of tail handling required (calculated on device)
/// @param p_smem Pointer to shared memory
/// @return Accumulated result tile in registers
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
{
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
constexpr bool hot_loop = has_hot_loop_.value;
constexpr auto tail_num = tail_number_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
bq_dram_block_window_tmp,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile