From f918177301f803857a5e9701da787c55a4de2f00 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Wed, 12 Feb 2025 17:33:21 +0000 Subject: [PATCH] Permute packed f4_t values --- include/ck/utility/data_type.hpp | 6 +-- include/ck/utility/scaled_type_convert.hpp | 4 +- include/ck/utility/type_convert.hpp | 8 ++-- test/data_type/test_mx_fp4.cpp | 43 +++++----------------- 4 files changed, 19 insertions(+), 42 deletions(-) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 522d5547d5..ee72b9bcd9 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -40,14 +40,14 @@ struct f4x2_pk_t { static_assert(I < 2, "Index is out of range."); if constexpr(I == 0) - return data & 0b00001111; - else return (data >> 4); + else + return data & 0b00001111; } __host__ __device__ inline type pack(const type x0, const type x1) { - return (x1 << 4) | (x0 & 0b00001111); + return (x0 << 4) | (x1 & 0b00001111); } }; diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 5b7a822e1f..3ed458aa4d 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -380,9 +380,9 @@ inline __host__ __device__ float2_t scaled_type_convert(e8m0_b return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); #else float2_t ret{utils::to_float( - scale, x.template AsType()[Number<0>{}].unpack<>(Number<1>{})), + scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{})), utils::to_float( - scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{}))}; + scale, x.template AsType()[Number<0>{}].unpack<>(Number<1>{}))}; return ret; #endif } diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index f14ff4b924..e9b2e3fff2 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -742,8 +742,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type(x[0] / scale); - uint8_t h = utils::sat_convert_to_type(x[1] / scale); + uint8_t l = utils::sat_convert_to_type(x[1] / scale); + uint8_t h = utils::sat_convert_to_type(x[0] / scale); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif @@ -969,8 +969,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type_sr(x[0] / scale, rng); - uint8_t h = utils::sat_convert_to_type_sr(x[1] / scale, rng); + uint8_t l = utils::sat_convert_to_type_sr(x[1] / scale, rng); + uint8_t h = utils::sat_convert_to_type_sr(x[0] / scale, rng); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 065c30d23b..ff70d9a3c7 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -264,11 +264,6 @@ TEST(MXFP4, DeviceScaledConvert) device_completed.FromDevice(&completed); device_out.FromDevice(out.data()); - for(ck::index_t id = 0; id < 256 * 16; id++) - { - printf("%f\n", out.data()[id]); - } - // V = X * P; X - E8M0 scale, P - FP4 // If X = NaN, then V = NaN regardless of P @@ -279,32 +274,14 @@ TEST(MXFP4, DeviceScaledConvert) ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; } - // If P in {Inf, NaN}, then V = P - std::set fp4_nan_ids; - fp4_nan_ids.insert(0b11111111); //-NaN - fp4_nan_ids.insert(0b01111111); // +NaN for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) { if(exp_id == e8m0_nan_id) continue; - for(auto fp4_nan_id : fp4_nan_ids) + for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++) { - auto idx = exp_id * 256 + fp4_nan_id; - ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; - } - } - - for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) - { - if(exp_id == e8m0_nan_id) - continue; - for(ck::index_t fp4_id = 0; fp4_id < 256; fp4_id++) - { - if(fp4_nan_ids.find(fp4_id) != fp4_nan_ids.end()) - continue; - uint8_t fp4_uid = static_cast(fp4_id); - auto idx = exp_id * 256 + fp4_uid; + auto idx = exp_id * 16 + fp4_uid; ASSERT_FLOAT_EQ(out[idx], type_convert(e8m0_bexp_t(exp_id)) * type_convert(f4_t(fp4_uid & 0b00001111))) @@ -319,19 +296,19 @@ TEST(MXFP4, DeviceScaledConvert) auto i = 256 * 16; // f4x2 -> f32x2 - EXPECT_EQ(out[i++], -powf(2.0f, -5.0f)); - EXPECT_EQ(out[i++], powf(2.0f, -8.0f)); + EXPECT_EQ(out[i++], 1.0f); + EXPECT_EQ(out[i++], -4.0f); // f32x2 -> f4x2 // RNE - EXPECT_EQ(out[i++], -4.0f); - EXPECT_EQ(out[i++], 2.0f); - // SR + EXPECT_EQ(out[i++], 0.5f); + EXPECT_EQ(out[i++], -2.0f); + // SR + EXPECT_EQ(out[i++], 0.5f); EXPECT_EQ(out[i++], -2.0f); - EXPECT_EQ(out[i++], 1.0f); /// Test round to nearest even - EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1]; EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; #if 1 EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; @@ -347,7 +324,7 @@ TEST(MXFP4, DeviceScaledConvert) EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Lowest())) << "out[i-1]: " << out[i - 1]; #endif - EXPECT_EQ(out[i++], type_convert(type_convert(312.5f))) + EXPECT_EQ(out[i++], type_convert(type_convert(5.0f))) << "out[i-1]: " << out[i - 1]; EXPECT_EQ(test_size, completed);