From e323d613ffeafc1057247ca123a480a2344de454 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Tue, 18 Feb 2025 19:47:36 +0000 Subject: [PATCH] Update test vector generator --- test/data_type/test_mx_fp4.cpp | 36 +++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index eb586a50fe..e12fe2de47 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -324,22 +324,23 @@ TEST(MXFP4, DeviceScaledConvert) EXPECT_EQ(test_size, i); } -__host__ __device__ float vec16_generator(ck::index_t i) +__host__ __device__ float vec16_generator(ck::index_t i, float scale) { - return type_convert(f4_t(i & 0b00001111)); + return scale * type_convert(f4_t(i & 0b00001111)); } -__host__ __device__ float vec32_generator(ck::index_t i) +__host__ __device__ float vec32_generator(ck::index_t i, float scale) { if(i < 16) { return vec16_generator( - i); // all positive values, then all negative values in ascending order + i, scale); // all positive values, then all negative values in ascending order } else { return vec16_generator( - 15 - (i % 16)); // all negative values, then all positive values in descending order + 15 - (i % 16), + scale); // all negative values, then all positive values in descending order } } @@ -363,8 +364,9 @@ __global__ void test_mx_fp4x32_device_scaled_convert(float* p_test, uint64_t* p_ f4x32_t f4x32{}; float32_t float32{}; - ck::static_for<0, N, 1>{}( - [&](auto ii) { float32[static_cast(ii)] = vec32_generator(ii); }); + ck::static_for<0, N, 1>{}([&](auto ii) { + float32[static_cast(ii)] = vec32_generator(ii, type_convert(scale2)); + }); f4x32 = f4_convert_rne(float32, type_convert(scale2)); @@ -399,7 +401,8 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert) auto scale2 = e8m0_bexp_t(2.0f); ck::static_for<0, N, 1>{}([&](auto ii) { - EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert(scale2)) + EXPECT_EQ(out[i++], + vec32_generator(ii, type_convert(scale2)) / type_convert(scale2)) << "ii: " << ii << std::endl; }); @@ -427,8 +430,9 @@ __global__ void test_mx_fp4x32_device_scaled_convert_sr(float* p_test, uint64_t* f4x32_t f4x32{}; float32_t float32{}; - ck::static_for<0, N, 1>{}( - [&](auto ii) { float32[static_cast(ii)] = vec32_generator(ii); }); + ck::static_for<0, N, 1>{}([&](auto ii) { + float32[static_cast(ii)] = vec32_generator(ii, type_convert(scale2)); + }); f4x32 = f4_convert_sr(float32, type_convert(scale2)); @@ -463,7 +467,8 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR) auto scale2 = e8m0_bexp_t(2.0f); ck::static_for<0, N, 1>{}([&](auto ii) { - EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert(scale2)) + EXPECT_EQ(out[i++], + vec32_generator(ii, type_convert(scale2)) / type_convert(scale2)) << "ii: " << ii << std::endl; }); @@ -493,8 +498,10 @@ __global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_ float32_t float32{}; ck::static_for<0, N / 2, 1>{}([&](auto ii) { f4x32.AsType()(ck::Number{}) - .pack(type_convert(vec32_generator(2 * ii) / type_convert(scale2)), - type_convert(vec32_generator(2 * ii + 1) / type_convert(scale2))); + .pack(type_convert(vec32_generator(2 * ii, type_convert(scale2)) / + type_convert(scale2)), + type_convert(vec32_generator(2 * ii + 1, type_convert(scale2)) / + type_convert(scale2))); }); float32 = scaled_type_convert(scale2, f4x32); @@ -525,7 +532,8 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) auto scale2 = e8m0_bexp_t(2.0f); ck::static_for<0, N, 1>{}([&](auto ii) { - EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert(scale2)) + EXPECT_EQ(out[i++], + vec32_generator(ii, type_convert(scale2)) / type_convert(scale2)) << "ii: " << ii << std::endl; });