Add tests

This commit is contained in:
Enrico Degregori
2026-01-19 16:50:35 +00:00
parent 8684a671fc
commit e79e609696
4 changed files with 43 additions and 16 deletions

View File

@@ -153,7 +153,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test
const float max_accumulated_value)
{
using ComputeType = std::conditional_t<
std::is_same_v<BDataType_, ck_tile::pk_fp4_raw_t>,
std::is_same_v<BDataType_, ck_tile::pk_fp4_t>,
ADataType_,
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>>;
// Calculate thresholds

View File

@@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using BF16 = ck_tile::bf16_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using PkFP4 = ck_tile::pk_fp4_t;
using E8M0 = ck_tile::e8m0_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
@@ -25,9 +28,12 @@ using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// clang-format off
using BQuant1D128Types = ::testing::Types<
// 1d cases with grouping only on k axis
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize>
>;
// clang-format on

View File

@@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using BF16 = ck_tile::bf16_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using PkFP4 = ck_tile::pk_fp4_t;
using E8M0 = ck_tile::e8m0_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
@@ -24,10 +27,13 @@ using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuant1D64Types = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>
>;
// clang-format on

View File

@@ -102,7 +102,7 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
struct GemmConfigMxFp4 : public GemmConfigBase
struct GemmConfigMx : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -662,7 +662,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B =
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? (K / 2) : K;
std::is_same_v<BDataType, ck_tile::pk_fp4_t> ? (K / 2) : K;
const ck_tile::index_t stride_C = N;
// BQuant uses block/grouped quantization for B matrix
@@ -674,7 +674,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? K / 2 : K,
std::is_same_v<BDataType, ck_tile::pk_fp4_t> ? K / 2 : K,
N,
stride_B,
this->is_row_major(BLayout{})));
@@ -683,14 +683,29 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{125.f, 130.f}(bq_bqk_bqn);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
}
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
{
// Fills `scales` with power-of-two values: one float exponent per element is
// drawn into a temporary tensor using range_min/range_max, and each element of
// `scales` is then set to 2^exponent converted to QDataType.
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
// e8m0_t is basically an exponent of float32, so the scales are generated as
// powers of two rather than arbitrary real values.
// Temporary float tensor with the same lengths as `scales`; holds the exponents.
ck_tile::HostTensor<float> pow2(scales.get_lengths());
// NOTE(review): presumably fills with integer-valued floats between range_min
// and range_max — confirm the bound inclusivity against the filler's definition.
ck_tile::FillUniformDistributionIntegerValue<float>{range_min, range_max}(pow2);
scales.ForEach([&](auto& self, const auto& i) {
// std::exp2 turns the drawn exponent into 2^exponent before the cast to QDataType.
self(i) = static_cast<QDataType>(std::exp2(pow2(i)));
});
};
gen_scales(bq_bqk_bqn, -2, 2);
}
else
{
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
}
@@ -775,7 +790,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
c_m_n_host_ref.SetZero();
// Run reference BQuant implementation
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
ck_tile::reference_mxfp4gemm_quant<ADataType,
QDataType,
BDataType,
@@ -865,14 +880,14 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
using GemmPipeline = std::conditional_t<
PreshuffleB == false,
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
std::conditional_t<std::is_same_v<QDataType, ck_tile::e8m0_t>,
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
ADataType,
BDataType>,
ck_tile::tuple<>,