From 658fb530ab4844376d2d9db14f64bf03b6e8a893 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Fri, 31 Oct 2025 15:07:06 -0400 Subject: [PATCH] test(grouped_gemm): add unit tests for grouped_gemm bquant with preshuffleB true (#3119) * add tensorwise quant in grouped gemm * fix example issue * update test cases * format codes * clang format * use GTEST_FAIL * add bquant to grouped_gemm * add tensorwise quant in grouped gemm * fix example issue * update test cases * format codes * clang format * use GTEST_FAIL * fix a bug in test_grouped_gemm_util * skip test when use wmma on grouped_quant kernel * change cmake * fix a bug in test_grouped_gemm_util * skip test when use wmma on grouped_quant kernel * change cmake * tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm * Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * feat: add bf8 support * chore: remove unnecessary decltype usage * chore: add default quant_mode to function signature as fallback * fix: pass correct runtime pipeline params in grouped_gemm bquant kernel Calculate has_hot_loop, num_loop, and tail_number on device side for each GEMM problem instead of using default values. This fixes incorrect results when different problems in the group have different K dimensions. * chore: set default quant mode in function signature * test: add additional test cases to cover edge case of no hotloop * change code based on comments * WIP: bquant preshuffle b compiles but gives numerical error * feat(grouped_gemm_quant): bquant with preshuffleB support added to grouped_gemm example & kernel * refactor: refactor code after merge commit * chore: remove print statements * test(grouped_gemm): split test cases by quant mode to reduce compilation time and add bquant-preshuffleB mode test cases --------- Co-authored-by: kyle-256 Co-authored-by: ThomasNing Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> [ROCm/composable_kernel commit: 8f1274d9b655c2584b3643acac07ef813f31238e] --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 19 ++-- .../17_grouped_gemm/quant_grouped_gemm.hpp | 45 +++++++-- .../quant_run_grouped_gemm_example.inc | 38 +++++++- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 2 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 4 + .../kernel/grouped_gemm_quant_kernel.hpp | 83 +++++++++++++++- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 27 ++++++ .../ck_tile/grouped_gemm_quant/CMakeLists.txt | 11 ++- .../test_grouped_gemm_quant.cpp | 40 ++++---- .../test_grouped_gemm_quant_bquant.cpp | 33 +++++++ .../test_grouped_gemm_quant_rowcol.cpp | 35 +++++++ .../test_grouped_gemm_quant_tensor.cpp | 35 +++++++ .../test_grouped_gemm_quant_ut_cases.inc | 30 +++++- .../test_grouped_gemm_util_quant.hpp | 97 +++++++++++++------ 14 files changed, 425 insertions(+), 74 deletions(-) create mode 100644 test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp create mode 100644 test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp create mode 100644 test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 1a913fcfc1..59ff086dca 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -49,7 +49,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::kPadN, GemmConfig::kPadK, false, // PreshuffleQuant - false, // PreshuffleB + GemmConfig::PreshuffleB, // PreshuffleB ALayout, BLayout, CLayout, @@ -58,7 +58,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, BQLayout, GemmConfig::TransposeC, GemmConfig::DoubleSmemBuffer, - true>; + true>; // Persistence float ave_time{0}; @@ -86,10 +86,14 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, BDataType, scheduler>>::type; - using GemmPipeline = - typename std::conditional, - ck_tile::GemmPipelineAgBgCrCompV3>::type; + using GemmPipeline = std::conditional_t< + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant, + ck_tile::GemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(argc, argv); + int result1 = !run_grouped_gemm_example(argc, argv); + return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index 2885ce54fd..5d86286b66 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -10,9 +10,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -31,6 +28,22 @@ constexpr ck_tile::index_t get_k_warp_tile() #endif } +template +constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + template struct GemmTypeConfig; @@ -67,8 +80,9 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool PreshuffleB = false; }; template @@ -85,10 +99,26 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +}; - static constexpr bool DoubleSmemBuffer = false; +template +struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + get_k_from_preshuffled_warp_tile(); + + static constexpr bool PreshuffleB = true; + static constexpr bool DoubleSmemBuffer = true; }; using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; @@ -118,7 +148,8 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol"); + .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol") + .insert("init", "0", "0. Random, 2. One(s) (Constant)"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 152df38bff..e71a6b8d30 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -163,6 +163,7 @@ int run_grouped_gemm_example_with_layouts(int argc, const int repeat = arg_parser.get_int("repeat"); const int warmup = arg_parser.get_int("warmup"); const int kbatch = arg_parser.get_int("kbatch"); + const int init_method = arg_parser.get_int("init"); bool validate = arg_parser.get_bool("validate"); const ck_tile::index_t QuantGroupSize = 128; @@ -203,6 +204,7 @@ int run_grouped_gemm_example_with_layouts(int argc, for(int i = 0; i < group_count; i++) { + Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); Ks.push_back(512 + 128 * i); @@ -280,6 +282,12 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + stride_AQs[i] = 0; // No A quantization + stride_BQs[i] = + ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout)); + } a_m_k_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); @@ -313,10 +321,20 @@ int run_grouped_gemm_example_with_layouts(int argc, << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc << " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl; - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); - ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); - ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(bq_tensors[i]); + } + else + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + } a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); @@ -329,8 +347,18 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_dev_buf.push_back( std::make_unique(bq_tensors[i].get_element_space_size_in_bytes())); + if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::HostTensor b_shuffle_host = + ck_tile::shuffle_b(b_k_n_tensors[i]); + b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data()); + } + else + { + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + } + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); - b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); c_m_n_dev_buf[i]->SetZero(); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 670f4b0575..87f6c753b4 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -20,7 +20,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; } diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 6f049a20a7..0afa70c99d 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -483,6 +483,7 @@ struct QuantGemmKernel const QuantGemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset) { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) @@ -790,6 +791,7 @@ struct QuantGemmKernel }(); if constexpr(PreshuffleB) { + return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view); } else @@ -802,6 +804,7 @@ struct QuantGemmKernel CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { + const auto& a_pad_view = views.at(I0); const auto& aq_pad_view = views.at(I1); const auto& b_pad_view = views.at(I2); @@ -867,6 +870,7 @@ struct QuantGemmKernel const auto& b_block_window = [&]() { if constexpr(PreshuffleB) { + return make_tile_window( b_pad_view, make_tuple(number{}, diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index eee4017c12..75ac1ca6ab 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -317,13 +317,88 @@ struct QuantGroupedGemmKernel const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); CDataType* c_ptr = static_cast(kargs.c_ptr); - static_assert(GemmPipeline::DoubleSmemBuffer == false, - "DoubleSmemBuffer needs to be false"); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; - RunGemmWithPipelineSelection( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + // Only for BQuantGrouped DoubleSmemBuffer is supported + if constexpr(GemmPipeline::DoubleSmemBuffer == true && + kQuantType == QuantType::BQuantGrouped) + { + + __shared__ char smem_ptr_1[GetSmemSize()]; + RunGemmWithPipelineSelection2LDS(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else + { + + RunGemmWithPipelineSelection(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + + template + CK_TILE_DEVICE static void + RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, + const BDataType* b_ptr, + const AQDataType* aq_ptr, + const BQDataType* bq_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + void* smem_ptr_1, + const QuantGroupedGemmKernelArgs& kargs, + const typename Base::SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped"); + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + Base::template MakeGemmTensorViews( + a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(Base::I0); + const auto& b_block_window = gemm_tile_windows.at(Base::I2); + + const auto& bq_block_window = gemm_tile_windows.at(Base::I3); + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + b_block_window, + bq_block_window, + num_loop, + tail_num, + smem_ptr_0, + smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); } /** diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 196f47badb..d0427786fd 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -458,6 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV void* p_smem_ping, void* p_smem_pong) const { + return operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, @@ -467,5 +468,31 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV p_smem_ping, p_smem_pong); } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + TailNumber tail_number, + void* p_smem_ping, + void* p_smem_pong) const + { + const auto RunPipeline = [&](auto bool_val, auto tail_num_) { + (void)bool_val; // Suppress unused parameter warning + constexpr auto tail_num = tail_num_.value; + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + bq_dram_block_window_tmp, + num_loop, + p_smem_ping, + p_smem_pong); + }; + return Base::TailHandler(RunPipeline, true, tail_number); + } }; } // namespace ck_tile diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index fddd8b69b2..3f32413f59 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -4,7 +4,14 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94|gfx95") - add_gtest_executable(test_ck_tile_grouped_gemm_quant test_grouped_gemm_quant.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + # Split into three separate test executables for faster parallel compilation + add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index 599a669c0a..3ed82affc0 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -22,26 +22,28 @@ using BQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>, + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant> + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp new file mode 100644 index 0000000000..8ac4d73cb4 --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using BQuant = std::integral_constant; + +// clang-format off +using KernelTypes_BQuant = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant, KernelTypes_BQuant); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp new file mode 100644 index 0000000000..be21cc3db5 --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using RowColQuant = std::integral_constant; + +// clang-format off +using KernelTypes_RowCol = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_RowCol, KernelTypes_RowCol); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_RowCol +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp new file mode 100644 index 0000000000..eac4ba3d5c --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using TensorQuant = std::integral_constant; + +// clang-format off +using KernelTypes_Tensor = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_Tensor, KernelTypes_Tensor); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_Tensor +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc index 0b522f82f3..0cb9efc921 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_ut_cases.inc @@ -1,6 +1,6 @@ #pragma once -TYPED_TEST(TestCkTileGroupedGemmQuant, Basic) +TYPED_TEST(TEST_CLASS_NAME, Basic) { const int group_count = 8; std::vector Ms; @@ -29,7 +29,7 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic) // No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop // Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128) -TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) // +TYPED_TEST(TEST_CLASS_NAME, SmallUniform) // { const int group_count = 2; std::vector Ms; @@ -55,3 +55,29 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) // this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); } +TYPED_TEST(TEST_CLASS_NAME, OddTail) // +{ + const int group_count = 2; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256); + Ns.push_back(256); + Ks.push_back(128); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 5e9956ab5b..5d8534cc11 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -17,23 +17,40 @@ template class TestCkTileGroupedGemmQuant : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using CLayout = std::tuple_element_t<2, Tuple>; - using ADataType = std::tuple_element_t<3, Tuple>; - using AQDataType = std::tuple_element_t<4, Tuple>; - using BDataType = std::tuple_element_t<5, Tuple>; - using BQDataType = std::tuple_element_t<6, Tuple>; - using AccDataType = std::tuple_element_t<7, Tuple>; - using CDataType = std::tuple_element_t<8, Tuple>; - static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value; - using DsLayout = ck_tile::tuple<>; - using DsDataType = ck_tile::tuple<>; - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - using AQLayout = Row; - using BQLayout = Col; - static constexpr bool Persistent = true; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using AQDataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using BQDataType = std::tuple_element_t<6, Tuple>; + using AccDataType = std::tuple_element_t<7, Tuple>; + using CDataType = std::tuple_element_t<8, Tuple>; + static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using AQLayout = Row; + using BQLayout = Col; + static constexpr bool Persistent = true; + static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value; + + template + static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() + { +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif + } struct GroupedGemKernelParam_Mfma { @@ -52,7 +69,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test static const ck_tile::index_t M_Warp_Tile = 32; static const ck_tile::index_t N_Warp_Tile = 32; - static const ck_tile::index_t K_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = + TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile(); }; using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; @@ -66,8 +85,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test const ck_tile::index_t num_groups, void* kargs_ptr) { - constexpr bool TransposeC = false; - constexpr bool DoubleSmemBuffer = false; + constexpr bool TransposeC = false; + constexpr bool DoubleSmemBuffer = + PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; @@ -90,7 +110,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test GroupedGemKernelParam::kPadN, GroupedGemKernelParam::kPadK, false, - false, + PreshuffleB, ALayout, BLayout, CLayout, @@ -126,11 +146,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test BDataType, scheduler>>::type; - using GemmPipeline = typename std::conditional< - QuantType == ck_tile::QuantType::BQuantGrouped, - ck_tile::BQuantGemmPipelineAgBgCrCompV3, - ck_tile::GemmPipelineAgBgCrCompV3>::type; - + using GemmPipeline = std::conditional_t< + QuantType == ck_tile::QuantType::RowColQuant || + QuantType == ck_tile::QuantType::TensorQuant, + ck_tile::GemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblemToDevice(a_m_k_tensors[i].data()); - b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + + if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped) + { + auto b_shuffle_host = + ck_tile::shuffle_b(b_k_n_tensors[i]); + b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data()); + } + else + { + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + } + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); c_m_n_dev_buf[i]->SetZero(); @@ -485,3 +518,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test EXPECT_TRUE(pass); } }; + +// Aliases for split test files +template +using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant; + +template +using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant; + +template +using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant;