diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 1509778533..449f6fc777 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -497,11 +497,11 @@ __global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_ f4x32_t f4x32{}; float32_t float32{}; ck::static_for<0, N / 2, 1>{}([&](auto ii) { - f4x32.AsType()(ck::Number{}) - .pack(type_convert(vec32_generator(2 * ii, type_convert(scale2)) / - type_convert(scale2)), - type_convert(vec32_generator(2 * ii + 1, type_convert(scale2)) / - type_convert(scale2))); + f4x32.AsType()(ck::Number{}) = f4x2_pk_t{}.pack( + type_convert(vec32_generator(2 * ii, type_convert(scale2)) / + type_convert(scale2)), + type_convert(vec32_generator(2 * ii + 1, type_convert(scale2)) / + type_convert(scale2))); }); float32 = scaled_type_convert(scale2, f4x32); @@ -532,8 +532,7 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) auto scale2 = e8m0_bexp_t(2.0f); ck::static_for<0, N, 1>{}([&](auto ii) { - EXPECT_EQ(out[i++], - vec32_generator(ii, type_convert(scale2)) / type_convert(scale2)) + EXPECT_EQ(out[i++], vec32_generator(ii, type_convert(scale2))) << "ii: " << ii << std::endl; });