[CK_TILE] Add Bquant to Grouped Gemm (#3063)

* update test cases

* format codes

* use GTEST_FAIL

* add bquant to grouped_gemm

* fix a bug in test_grouped_gemm_util

* skip test when use wmma on grouped_quant kernel

* add tensorwise quant in grouped gemm

* fix example issue

* update test cases

* format codes

* fix a bug in test_grouped_gemm_util

* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm

* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* feat: add bf8 support

* chore: remove unnecessary decltype usage

* chore: add default quant_mode to function signature as fallback

* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel

Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.

* chore: set default quant mode in function signature

* test: add additional test cases to cover edge case of no hotloop

* chore: clang formatting

---------

Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Aviral Goel
2025-10-28 10:20:24 -04:00
committed by GitHub
parent 1c17bae816
commit 4368fd9f57
8 changed files with 276 additions and 104 deletions

View File

@@ -18,6 +18,7 @@ using True = ck_tile::bool_constant<true>;
using False = ck_tile::bool_constant<false>;
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
// clang-format off
using KernelTypes = ::testing::Types<
@@ -31,16 +32,16 @@ using KernelTypes = ::testing::Types<
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>,
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant>
>;
// clang-format on

View File

@@ -26,3 +26,32 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
}
// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop
// Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128)
TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
{
const int group_count = 2;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Cs;
std::vector<int> stride_AQs;
std::vector<int> stride_BQs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256);
Ns.push_back(256);
Ks.push_back(256);
stride_As.push_back(0);
stride_Bs.push_back(0);
stride_Cs.push_back(0);
stride_AQs.push_back(0);
stride_BQs.push_back(0);
}
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
}

View File

@@ -107,7 +107,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
constexpr bool transpose_c = false;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using QuantGemmProblem =
using QuantGemmProblem = typename std::conditional<
QuantType == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
128>, // QuantGroupSize
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -116,9 +124,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
scheduler>>::type;
using GemmPipeline = typename std::conditional<
QuantType == ck_tile::QuantType::BQuantGrouped,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -244,6 +256,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / 128; // Group quantization: BQK = K / GroupSize
if(K % 128 != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
}
}
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{}));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{}));
@@ -258,7 +279,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
else if constexpr(QuantType == ck_tile::QuantType::TensorQuant)
{
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
stride_AQs[i] = 0; // No A quantization
stride_BQs[i] =
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout()));
}
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
@@ -285,6 +312,15 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
1, 1, stride_BQs[i], is_row_major(BQLayout()))));
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
aq_tensors.push_back(
ck_tile::HostTensor<AQDataType>(ck_tile::host_tensor_descriptor(
0, AQK, stride_AQs[i], is_row_major(AQLayout{}))));
bq_tensors.push_back(
ck_tile::HostTensor<BQDataType>(ck_tile::host_tensor_descriptor(
BQK, N, stride_BQs[i], is_row_major(BQLayout()))));
}
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc
@@ -373,7 +409,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr);
}
@@ -420,6 +455,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
bq_tensors[i],
c_m_n_host_ref);
}
else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
128,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());