mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] Add Bquant to Grouped Gemm (#3063)
* update test cases * format codes * use GTEST_FAIL * add bquant to grouped_gemm * fix a bug in test_grouped_gemm_util * skip test when use wmma on grouped_quant kernel * add tensorwise quant in grouped gemm * fix example issue * update test cases * format codes * fix a bug in test_grouped_gemm_util * tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm * Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * feat: add bf8 support * chore: remove unnecessary decltype usage * chore: add default quant_mode to function signature as fallback * fix: pass correct runtime pipeline params in grouped_gemm bquant kernel Calculate has_hot_loop, num_loop, and tail_number on device side for each GEMM problem instead of using default values. This fixes incorrect results when different problems in the group have different K dimensions. * chore: set default quant mode in function signature * test: add additional test cases to cover edge case of no hotloop * chore: clang formatting --------- Co-authored-by: kyle-256 <Kyle.Zhao@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -472,6 +472,49 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/// @brief Runtime pipeline dispatch operator for grouped GEMM kernels.
|
||||
///
|
||||
/// This operator is used by grouped GEMM kernels where pipeline parameters
|
||||
/// (has_hot_loop, num_loop, tail_number) are calculated on the device side
|
||||
/// at runtime, not on the host side during compilation. This is necessary
|
||||
/// because different GEMM problems in the group may have different K dimensions,
|
||||
/// requiring different pipeline configurations that cannot be determined at
|
||||
/// compile time.
|
||||
///
|
||||
/// @param a_dram_block_window_tmp Block window for A tensor in DRAM
|
||||
/// @param b_dram_block_window_tmp Block window for B tensor in DRAM
|
||||
/// @param bq_dram_block_window_tmp Block window for BQ (quantization scale) tensor in DRAM
|
||||
/// @param num_loop Number of main loop iterations (calculated on device)
|
||||
/// @param has_hot_loop Whether the pipeline has a hot loop (calculated on device)
|
||||
/// @param tail_number Type of tail handling required (calculated on device)
|
||||
/// @param p_smem Pointer to shared memory
|
||||
/// @return Accumulated result tile in registers
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
|
||||
constexpr bool hot_loop = has_hot_loop_.value;
|
||||
constexpr auto tail_num = tail_number_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user