[CK_TILE] Add Bquant to Grouped Gemm (#3063)

* update test cases

* format codes

* use GTEST_FAIL

* add bquant to grouped_gemm

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* fix a bug in test_grouped_gemm_util

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* chore: clang formatting

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-10-28 10:20:24 -04:00
committed by GitHub
parent 1c17bae816
commit 4368fd9f57
8 changed files with 276 additions and 104 deletions

View File

@@ -472,6 +472,49 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
num_loop,
p_smem);
}
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
///
/// This operator is used by grouped GEMM kernels where pipeline parameters
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
/// at runtime, not on the host side during compilation. This is necessary
/// because different GEMM problems in the group may have different K dimensions,
/// requiring different pipeline configurations that cannot be determined at
/// compile time.
///
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
/// @param num_loop Number of main loop iterations (calculated on device)
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
/// @param tail_number Type of tail handling required (calculated on device)
/// @param p_smem Pointer to shared memory
/// @return Accumulated result tile in registers
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
bool has_hot_loop,
TailNumber tail_number,
void* p_smem) const
{
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
constexpr bool hot_loop = has_hot_loop_.value;
constexpr auto tail_num = tail_number_.value;
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
bq_dram_block_window_tmp,
num_loop,
p_smem);
};
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
}
};
} // namespace ck_tile