mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Add tests
This commit is contained in:
@@ -153,7 +153,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType = std::conditional_t<
|
||||
std::is_same_v<BDataType_, ck_tile::pk_fp4_raw_t>,
|
||||
std::is_same_v<BDataType_, ck_tile::pk_fp4_t>,
|
||||
ADataType_,
|
||||
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>>;
|
||||
// Calculate thresholds
|
||||
|
||||
@@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
@@ -25,9 +28,12 @@ using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
// clang-format off
|
||||
using BQuant1D128Types = ::testing::Types<
|
||||
// 1d cases with grouping only on k axis
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using PkFP4 = ck_tile::pk_fp4_t;
|
||||
using E8M0 = ck_tile::e8m0_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
|
||||
@@ -24,10 +27,13 @@ using GroupSize64 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuant1D64Types = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, PkFP4, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF8, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF16, BF16, E8M0, BF16, BQuantGrouped, GemmConfigMx, GroupSize64>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigMxFp4 : public GemmConfigBase
|
||||
struct GemmConfigMx : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -662,7 +662,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B =
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? (K / 2) : K;
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t> ? (K / 2) : K;
|
||||
const ck_tile::index_t stride_C = N;
|
||||
|
||||
// BQuant uses block/grouped quantization for B matrix
|
||||
@@ -674,7 +674,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t> ? K / 2 : K,
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t> ? K / 2 : K,
|
||||
N,
|
||||
stride_B,
|
||||
this->is_row_major(BLayout{})));
|
||||
@@ -683,14 +683,29 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{125.f, 130.f}(bq_bqk_bqn);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
|
||||
{
|
||||
auto gen_scales = [&](auto& scales, float range_min, float range_max) {
|
||||
// e8m0_t is basically an exponent of float32
|
||||
ck_tile::HostTensor<float> pow2(scales.get_lengths());
|
||||
ck_tile::FillUniformDistributionIntegerValue<float>{range_min, range_max}(pow2);
|
||||
scales.ForEach([&](auto& self, const auto& i) {
|
||||
self(i) = static_cast<QDataType>(std::exp2(pow2(i)));
|
||||
});
|
||||
};
|
||||
gen_scales(bq_bqk_bqn, -2, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{-1.0f, 1.0f}(bq_bqk_bqn);
|
||||
}
|
||||
|
||||
@@ -775,7 +790,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference BQuant implementation
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
if constexpr(std::is_same_v<QDataType, ck_tile::e8m0_t>)
|
||||
ck_tile::reference_mxfp4gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
@@ -865,14 +880,14 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
PreshuffleB == false,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
std::conditional_t<std::is_same_v<QDataType, ck_tile::e8m0_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
|
||||
ADataType,
|
||||
BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
|
||||
Reference in New Issue
Block a user