test(grouped_gemm): add unit tests for grouped_gemm bquant with preshuffleB true (#3119)

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* add bquant to grouped_gemm

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* clang format

* use GTEST_FAIL

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* change cmake

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* change code based on comments

* WIP: bquant preshuffle b compiles but gives numerical error

* feat(grouped_gemm_quant): bquant with preshuffleB support added to grouped_gemm example & kernel

* refactor: refactor code after merge commit

* chore: remove print statements

* test(grouped_gemm): split test cases by quant mode to reduce compilation time and add bquant-preshuffleB mode test cases

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

[ROCm/composable_kernel commit: 8f1274d9b6]
This commit is contained in:
Aviral Goel
2025-10-31 15:07:06 -04:00
committed by GitHub
parent 27dc4d9833
commit 658fb530ab
14 changed files with 425 additions and 74 deletions

View File

@@ -49,7 +49,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
false, // PreshuffleB
GemmConfig::PreshuffleB, // PreshuffleB
ALayout,
BLayout,
CLayout,
@@ -58,7 +58,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
true>;
true>; // Persistence
float ave_time{0};
@@ -86,10 +86,14 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BDataType,
scheduler>>::type;
using GemmPipeline =
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -141,5 +145,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
return result1;
}

View File

@@ -10,9 +10,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
@@ -31,6 +28,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -67,8 +80,9 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
};
template <typename PrecType>
@@ -85,10 +99,26 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
static constexpr bool DoubleSmemBuffer = false;
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
@@ -118,7 +148,8 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -163,6 +163,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
const int init_method = arg_parser.get_int("init");
bool validate = arg_parser.get_bool("validate");
const ck_tile::index_t QuantGroupSize = 128;
@@ -203,6 +204,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
@@ -280,6 +282,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
stride_AQs[i] = 0; // No A quantization
stride_BQs[i] =
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
}
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
@@ -313,10 +321,20 @@ int run_grouped_gemm_example_with_layouts(int argc,
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
}
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
@@ -329,8 +347,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::HostTensor<BDataType> b_shuffle_host =
ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
}
else
{
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
}
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();

View File

@@ -20,7 +20,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}

View File

@@ -483,6 +483,7 @@ struct QuantGemmKernel
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -790,6 +791,7 @@ struct QuantGemmKernel
}();
if constexpr(PreshuffleB)
{
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
}
else
@@ -802,6 +804,7 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
@@ -867,6 +870,7 @@ struct QuantGemmKernel
const auto& b_block_window = [&]() {
if constexpr(PreshuffleB)
{
return make_tile_window(
b_pad_view,
make_tuple(number<GemmPipeline::flatNPerWarp>{},

View File

@@ -317,13 +317,88 @@ struct QuantGroupedGemmKernel
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
static_assert(GemmPipeline::DoubleSmemBuffer == false,
"DoubleSmemBuffer needs to be false");
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
RunGemmWithPipelineSelection(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
// Only for BQuantGrouped DoubleSmemBuffer is supported
if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
kQuantType == QuantType::BQuantGrouped)
{
__shared__ char smem_ptr_1[GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
void* smem_ptr_1,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
b_block_window,
bq_block_window,
num_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
/**

View File

@@ -458,6 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()<TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
@@ -467,5 +468,31 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
p_smem_ping,
p_smem_pong);
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename BQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* p_smem_ping,
void* p_smem_pong) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
return operator()<tail_num>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
bq_dram_block_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}
};
} // namespace ck_tile

View File

@@ -4,7 +4,14 @@ if(CK_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_grouped_gemm_quant test_grouped_gemm_quant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
# Split into three separate test executables for faster parallel compilation
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -22,26 +22,28 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant>
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
>;
// clang-format on

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// clang-format off
using KernelTypes_BQuant = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant, KernelTypes_BQuant);
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
// clang-format off
using KernelTypes_RowCol = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_RowCol, KernelTypes_RowCol);
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_RowCol
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_util_quant.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
// clang-format off
using KernelTypes_Tensor = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_Tensor, KernelTypes_Tensor);
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_Tensor
#include "test_grouped_gemm_quant_ut_cases.inc"
#undef TEST_CLASS_NAME

View File

@@ -1,6 +1,6 @@
#pragma once
TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
TYPED_TEST(TEST_CLASS_NAME, Basic)
{
const int group_count = 8;
std::vector<int> Ms;
@@ -29,7 +29,7 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop
// Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128)
TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
TYPED_TEST(TEST_CLASS_NAME, SmallUniform) //
{
const int group_count = 2;
std::vector<int> Ms;
@@ -55,3 +55,29 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
}
TYPED_TEST(TEST_CLASS_NAME, OddTail) //
{
const int group_count = 2;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
std::vector<int> stride_AQs;
std::vector<int> stride_BQs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256);
Ns.push_back(256);
Ks.push_back(128);
stride_As.push_back(0);
stride_Bs.push_back(0);
stride_Cs.push_back(0);
stride_AQs.push_back(0);
stride_BQs.push_back(0);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
}

View File

@@ -17,23 +17,40 @@ template <typename Tuple>
class TestCkTileGroupedGemmQuant : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using AQDataType = std::tuple_element_t<4, Tuple>;
using BDataType = std::tuple_element_t<5, Tuple>;
using BQDataType = std::tuple_element_t<6, Tuple>;
using AccDataType = std::tuple_element_t<7, Tuple>;
using CDataType = std::tuple_element_t<8, Tuple>;
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using AQLayout = Row;
using BQLayout = Col;
static constexpr bool Persistent = true;
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using AQDataType = std::tuple_element_t<4, Tuple>;
using BDataType = std::tuple_element_t<5, Tuple>;
using BQDataType = std::tuple_element_t<6, Tuple>;
using AccDataType = std::tuple_element_t<7, Tuple>;
using CDataType = std::tuple_element_t<8, Tuple>;
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using AQLayout = Row;
using BQLayout = Col;
static constexpr bool Persistent = true;
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
struct GroupedGemKernelParam_Mfma
{
@@ -52,7 +69,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
static const ck_tile::index_t M_Warp_Tile = 32;
static const ck_tile::index_t N_Warp_Tile = 32;
static const ck_tile::index_t K_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile =
TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile<BDataType,
M_Warp_Tile>();
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
@@ -66,8 +85,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
const ck_tile::index_t num_groups,
void* kargs_ptr)
{
constexpr bool TransposeC = false;
constexpr bool DoubleSmemBuffer = false;
constexpr bool TransposeC = false;
constexpr bool DoubleSmemBuffer =
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
@@ -90,7 +110,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
false,
PreshuffleB,
ALayout,
BLayout,
CLayout,
@@ -126,11 +146,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
BDataType,
scheduler>>::type;
using GemmPipeline = typename std::conditional<
QuantType == ck_tile::QuantType::BQuantGrouped,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
using GemmPipeline = std::conditional_t<
QuantType == ck_tile::QuantType::RowColQuant ||
QuantType == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -344,7 +366,18 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
bq_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped)
{
auto b_shuffle_host =
ck_tile::shuffle_b<GroupedGemKernelParam_Mfma>(b_k_n_tensors[i]);
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
}
else
{
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
}
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
@@ -485,3 +518,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
EXPECT_TRUE(pass);
}
};
// Aliases for split test files
template <typename Tuple>
using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant<Tuple>;
template <typename Tuple>
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;