diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 6ba0bc9407..eb586a50fe 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -338,9 +338,8 @@ __host__ __device__ float vec32_generator(ck::index_t i) } else { - return type_convert(ck::NumericLimits::Max()) - - vec16_generator( - i); // all positive values, then all negative values in descending order + return vec16_generator( + 15 - (i % 16)); // all negative values, then all positive values in descending order } } @@ -371,9 +370,9 @@ __global__ void test_mx_fp4x32_device_scaled_convert(float* p_test, uint64_t* p_ ck::static_for<0, N / 2, 1>{}([&](auto ii) { p_test[i++] = type_convert( - f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<0>{})); + f4_t(f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<0>{}))); p_test[i++] = type_convert( - f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<1>{})); + f4_t(f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<1>{}))); }); } @@ -424,7 +423,7 @@ __global__ void test_mx_fp4x32_device_scaled_convert_sr(float* p_test, uint64_t* return; } - auto scale2 = e8m0_bexp_t(8.0f); + auto scale2 = e8m0_bexp_t(2.0f); f4x32_t f4x32{}; float32_t float32{};