diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 30d955593e..6ba0bc9407 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -326,18 +326,21 @@ TEST(MXFP4, DeviceScaledConvert) __host__ __device__ float vec16_generator(ck::index_t i) { - return (i < 8 ? -1.0 : 1.0) * powf(2.0f, i % 8); + return type_convert(f4_t(i & 0b00001111)); } __host__ __device__ float vec32_generator(ck::index_t i) { if(i < 16) { - return vec16_generator(i % 16); + return vec16_generator( + i); // all positive values, then all negative values in ascending order } else { - return 1.5f * vec16_generator(i % 16); + return type_convert(ck::NumericLimits::Max()) - + vec16_generator( + i); // all positive values, then all negative values in descending order } } @@ -393,10 +396,12 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert) device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); - auto i = 0; + auto i = 0; + auto scale2 = e8m0_bexp_t(2.0f); ck::static_for<0, N, 1>{}([&](auto ii) { - EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl; + EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert(scale2)) + << "ii: " << ii << std::endl; }); EXPECT_EQ(N, completed); @@ -455,10 +460,12 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR) device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); - auto i = 0; + auto i = 0; + auto scale2 = e8m0_bexp_t(2.0f); ck::static_for<0, N, 1>{}([&](auto ii) { - EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl; + EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert(scale2)) + << "ii: " << ii << std::endl; }); EXPECT_EQ(N, completed); @@ -481,14 +488,14 @@ __global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_ return; } - auto scale2 = e8m0_bexp_t(4.0f); + auto scale2 = e8m0_bexp_t(2.0f); f4x32_t f4x32{}; float32_t float32{}; ck::static_for<0, N / 2, 1>{}([&](auto ii) { f4x32.AsType()(ck::Number{}) - .pack(type_convert(vec32_generator(2 * ii) / 16.0f), - type_convert(vec32_generator(2 * ii + 1) / 16.0f)); + .pack(type_convert(vec32_generator(2 * ii) / type_convert(scale2)), + type_convert(vec32_generator(2 * ii + 1) / type_convert(scale2))); }); float32 = scaled_type_convert(scale2, f4x32); @@ -515,10 +522,12 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); - auto i = 0; + auto i = 0; + auto scale2 = e8m0_bexp_t(2.0f); ck::static_for<0, N, 1>{}([&](auto ii) { - EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl; + EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert(scale2)) + << "ii: " << ii << std::endl; }); EXPECT_EQ(N, completed); diff --git a/test/data_type/test_mx_fp4_repro.cpp b/test/data_type/test_mx_fp4_repro.cpp index a53c39ebdb..f705b774c4 100644 --- a/test/data_type/test_mx_fp4_repro.cpp +++ b/test/data_type/test_mx_fp4_repro.cpp @@ -85,15 +85,9 @@ TEST(MXFP4, FP4ToFP32) std::vector out(2, -1.0f); DeviceMem device_out(2 * sizeof(float)); - // DeviceMem device_completed(sizeof(uint64_t)); - - // device_out.SetValue(-21.0f); - // device_completed.SetValue(-21.0f); run_test_mx_fp4_to_fp32<<<1, 1>>>(static_cast(device_out.GetDeviceBuffer())); - // uint64_t completed = 0; - // device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); // f4x2 -> f32x2 @@ -106,12 +100,9 @@ TEST(MXFP4, FP32ToFP4RNE) std::vector out(2, -1.0f); DeviceMem device_out(2 * sizeof(float)); - // DeviceMem device_completed(sizeof(uint64_t)); run_test_mx_fp32_to_fp4_rne<<<1, 1>>>(static_cast(device_out.GetDeviceBuffer())); - // uint64_t completed = 0; - // device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); // f32x2 -> f4x2 @@ -125,12 +116,9 @@ TEST(MXFP4, FP32ToFP4SR) std::vector out(2, -1.0f); DeviceMem device_out(2 * sizeof(float)); - // DeviceMem device_completed(sizeof(uint64_t)); run_test_mx_fp32_to_fp4_sr<<<1, 1>>>(static_cast(device_out.GetDeviceBuffer())); - // uint64_t completed = 0; - // device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); // SR @@ -143,12 +131,9 @@ TEST(MXFP4, FP32ToFP4SRFailing) std::vector out(2, -1.0f); DeviceMem device_out(2 * sizeof(float)); - // DeviceMem device_completed(sizeof(uint64_t)); run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(static_cast(device_out.GetDeviceBuffer())); - // uint64_t completed = 0; - // device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); // SR