test(grouped_gemm): add unit tests for grouped_gemm bquant with preshuffleB true (#3119)

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* add bquant to grouped_gemm

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* change code based on comments

* WIP: bquant preshuffle b compiles but gives numerical error

* feat(grouped_gemm_quant): bquant with preshuffleB support added to grouped_gemm example & kernel

* refactor: refactor code after merge commit

* chore: remove print statements

* test(grouped_gemm): split test cases by quant mode to reduce compilation time and add bquant-preshuffleB mode test cases

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-10-31 15:07:06 -04:00
committed by GitHub
parent a33d98f8e2
commit 8f1274d9b6
14 changed files with 425 additions and 74 deletions

View File

@@ -4,7 +4,14 @@ if(CK_USE_OCP_FP8)
endif()
# Quantized grouped-GEMM tests are only added when GPU_TARGETS matches gfx94 or gfx95.
if(GPU_TARGETS MATCHES "gfx94|gfx95")
# Combined test executable built from test_grouped_gemm_quant.cpp.
add_gtest_executable(test_ck_tile_grouped_gemm_quant test_grouped_gemm_quant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# Split into three separate test executables for faster parallel compilation
# (one source file per quant mode: rowcol, tensor, bquant).
# NOTE(review): the combined test_ck_tile_grouped_gemm_quant target above is still
# listed next to the split targets; confirm whether it is meant to remain.
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -22,26 +22,28 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant>
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
>;
// clang-format on

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
// Short aliases for the data types, layouts and compile-time flags used in the
// type list below.
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
// NOTE(review): BF8 is declared but no tuple in this file uses it.
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Compile-time booleans used for the trailing PreshuffleB tuple entry.
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
// Quant mode selected by every tuple in this file.
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// Type list for the BQuant suite: the same FP8 configuration once with
// PreshuffleB = False and once with PreshuffleB = True.
// clang-format off
using KernelTypes_BQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant, KernelTypes_BQuant);
// The shared test cases are written against the TEST_CLASS_NAME macro, so it must
// be defined before the .inc file is included and is removed again afterwards.
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
// Short aliases for the data types, layouts and compile-time flags used in the
// type list below.
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
// NOTE(review): BF8 is declared but no tuple in this file uses it.
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Compile-time booleans used for the trailing PreshuffleB tuple entry.
// NOTE(review): only False appears in the type list of this file.
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
// Quant mode selected by every tuple in this file.
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
// Type list for the RowCol suite: FP8 inputs across the four A/B layout
// combinations, all with PreshuffleB = False.
// clang-format off
using KernelTypes_RowCol = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_RowCol, KernelTypes_RowCol);
// The shared test cases are written against the TEST_CLASS_NAME macro, so it must
// be defined before the .inc file is included and is removed again afterwards.
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_RowCol
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
// Short aliases for the data types, layouts and compile-time flags used in the
// type list below.
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
// NOTE(review): BF8 is declared but no tuple in this file uses it.
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Compile-time booleans used for the trailing PreshuffleB tuple entry.
// NOTE(review): only False appears in the type list of this file.
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
// Quant mode selected by every tuple in this file.
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
// Type list for the Tensor suite: FP8 inputs across the four A/B layout
// combinations, all with PreshuffleB = False.
// clang-format off
using KernelTypes_Tensor = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_Tensor, KernelTypes_Tensor);
// The shared test cases are written against the TEST_CLASS_NAME macro, so it must
// be defined before the .inc file is included and is removed again afterwards.
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_Tensor
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -1,6 +1,6 @@
#pragma once
TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
TYPED_TEST(TEST_CLASS_NAME, Basic)
{
const int group_count = 8;
std::vector<int> Ms;
@@ -29,7 +29,7 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop
// Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128)
TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
TYPED_TEST(TEST_CLASS_NAME, SmallUniform) //
{
const int group_count = 2;
std::vector<int> Ms;
@@ -55,3 +55,29 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
}
// Runs the fixture on two identical groups, each with M = 256, N = 256, K = 128.
// Every stride entry is passed as 0 (presumably the fixture substitutes its own
// default strides for 0 -- TODO confirm against Run()).
// NOTE(review): K = 128 equals the K tile size quoted in the SmallUniform comment;
// confirm these sizes actually reach the tail path this test is named after.
TYPED_TEST(TEST_CLASS_NAME, OddTail) //
{
    const int group_count = 2;

    // Per-group problem sizes: one entry per group, all groups identical.
    std::vector<int> Ms(group_count, 256);
    std::vector<int> Ns(group_count, 256);
    std::vector<int> Ks(group_count, 128);

    // Per-group strides for A, B, C and the two quantization-scale tensors.
    std::vector<int> stride_As(group_count, 0);
    std::vector<int> stride_Bs(group_count, 0);
    std::vector<int> stride_Cs(group_count, 0);
    std::vector<int> stride_AQs(group_count, 0);
    std::vector<int> stride_BQs(group_count, 0);

    this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
}

View File

@@ -17,23 +17,40 @@ template <typename Tuple>
class TestCkTileGroupedGemmQuant : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using AQDataType = std::tuple_element_t<4, Tuple>;
using BDataType = std::tuple_element_t<5, Tuple>;
using BQDataType = std::tuple_element_t<6, Tuple>;
using AccDataType = std::tuple_element_t<7, Tuple>;
using CDataType = std::tuple_element_t<8, Tuple>;
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using AQLayout = Row;
using BQLayout = Col;
static constexpr bool Persistent = true;
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using AQDataType = std::tuple_element_t<4, Tuple>;
using BDataType = std::tuple_element_t<5, Tuple>;
using BQDataType = std::tuple_element_t<6, Tuple>;
using AccDataType = std::tuple_element_t<7, Tuple>;
using CDataType = std::tuple_element_t<8, Tuple>;
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using AQLayout = Row;
using BQLayout = Col;
static constexpr bool Persistent = true;
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
// Picks the warp-tile K extent from the element size of PrecType and from
// M_Warp_Tile.
//
// Two-byte element types give 16 when M_Warp_Tile is 32 and 32 otherwise.
// Every other element size gives 64 / 128 (M_Warp_Tile == 32 / otherwise) when
// CK_GFX950_SUPPORT is defined, and 32 / 64 when it is not.
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
    constexpr bool two_byte_elems = (sizeof(PrecType) == 2);
    constexpr bool m_is_32        = (M_Warp_Tile == 32);

    // The two-byte result does not depend on the build target.
    constexpr ck_tile::index_t narrow_k = m_is_32 ? 16 : 32;

#if defined(CK_GFX950_SUPPORT)
    constexpr ck_tile::index_t wide_k = m_is_32 ? 64 : 128;
#else
    constexpr ck_tile::index_t wide_k = m_is_32 ? 32 : 64;
#endif

    return two_byte_elems ? narrow_k : wide_k;
}
struct GroupedGemKernelParam_Mfma
{
@@ -52,7 +69,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
static const ck_tile::index_t M_Warp_Tile = 32;
static const ck_tile::index_t N_Warp_Tile = 32;
static const ck_tile::index_t K_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile =
TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile<BDataType,
M_Warp_Tile>();
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
@@ -66,8 +85,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
const ck_tile::index_t num_groups,
void* kargs_ptr)
{
constexpr bool TransposeC = false;
constexpr bool DoubleSmemBuffer = false;
constexpr bool TransposeC = false;
constexpr bool DoubleSmemBuffer =
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
@@ -90,7 +110,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
false,
PreshuffleB,
ALayout,
BLayout,
CLayout,
@@ -126,11 +146,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
BDataType,
scheduler>>::type;
using GemmPipeline = typename std::conditional<
QuantType == ck_tile::QuantType::BQuantGrouped,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
using GemmPipeline = std::conditional_t<
QuantType == ck_tile::QuantType::RowColQuant ||
QuantType == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -344,7 +366,18 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
bq_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped)
{
auto b_shuffle_host =
ck_tile::shuffle_b<GroupedGemKernelParam_Mfma>(b_k_n_tensors[i]);
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
}
else
{
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
}
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
@@ -485,3 +518,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
EXPECT_TRUE(pass);
}
};
// One fixture name per quant mode. Each split test file registers its typed
// suite under one of these names; all three refer to the same fixture template.
template <class Tuple>
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;

template <class Tuple>
using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant<Tuple>;

template <class Tuple>
using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant<Tuple>;