diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index e9b2e3fff2..552b211821 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -734,7 +734,7 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0); + value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0); return value.f4x2_array[0]; #else union @@ -961,6 +961,7 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; + printf("%f, %f\n", x[0], x[1]); value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0); return value.f4x2_array[0]; #else diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index ff70d9a3c7..30d955593e 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -309,21 +309,14 @@ TEST(MXFP4, DeviceScaledConvert) /// Test round to nearest even EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1]; - EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; -#if 1 - EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; - EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; - EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; -#else - // NOTE: Host and Device have different behavior. - // Device returns NaN, while Host returns Max (saturation to finite value). + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) << "out[i-1]: " << out[i - 1]; EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) << "out[i-1]: " << out[i - 1]; EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Lowest())) << "out[i-1]: " << out[i - 1]; -#endif EXPECT_EQ(out[i++], type_convert(type_convert(5.0f))) << "out[i-1]: " << out[i - 1];