From e79e609696495c042c5b9ff02a8d68ba5bca363d Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Mon, 19 Jan 2026 16:50:35 +0000 Subject: [PATCH] Add tests --- .../gemm_block_scale/test_gemm_quant_base.hpp | 2 +- .../test_gemm_quant_bquant_1d_128.cpp | 12 +++++-- .../test_gemm_quant_bquant_1d_64.cpp | 14 ++++++--- .../test_gemm_quant_fixtures.hpp | 31 ++++++++++++++----- 4 files changed, 43 insertions(+), 16 deletions(-) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 5937b44229..e12974d857 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -153,7 +153,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test const float max_accumulated_value) { using ComputeType = std::conditional_t< - std::is_same_v, + std::is_same_v, ADataType_, std::conditional_t>; // Calculate thresholds diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp index d491d89ef4..d28fe62012 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; @@ -25,9 +28,12 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using BQuant1D128Types = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp index 1019caf1bc..4965d48a34 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; using BQuantGrouped = std::integral_constant; using GroupSize64 = ck_tile::QuantGroupShape>; @@ -24,10 +27,13 @@ using GroupSize64 = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant1D64Types = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index ca21bc69b7..c60901dad7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -102,7 +102,7 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -struct GemmConfigMxFp4 : public GemmConfigBase +struct GemmConfigMx : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -662,7 +662,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? (K / 2) : K; + std::is_same_v ? (K / 2) : K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -674,7 +674,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? K / 2 : K, + std::is_same_v ? K / 2 : K, N, stride_B, this->is_row_major(BLayout{}))); @@ -683,14 +683,29 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase{-0.5f, 0.5f}(a_m_k); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn); } else { ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + } + + if constexpr(std::is_same_v) + { + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{range_min, range_max}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + gen_scales(bq_bqk_bqn, -2, 2); + } + else + { ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); } @@ -775,7 +790,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase) + if constexpr(std::is_same_v) ck_tile::reference_mxfp4gemm_quant, + std::conditional_t, ck_tile::MxFp4GemmPipelineAgBgCrCompV3, ck_tile::BQuantGemmPipelineAgBgCrCompV3>, ck_tile::WPQuantBPipelineAgBgCrV2>; using GemmEpilogue = ck_tile::CShuffleEpilogue, + std::conditional_t, ADataType, BDataType>, ck_tile::tuple<>,