From d57de07d68ace505bcfb2210ccc1c40d4a376ffe Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 13 Jun 2025 07:37:29 +0000 Subject: [PATCH] fix f4x2 implementation to match linear memory order --- include/ck/ck.hpp | 2 +- include/ck/utility/data_type.hpp | 6 +++--- include/ck/utility/scaled_type_convert.hpp | 6 +++--- include/ck/utility/type_convert.hpp | 22 ++++++++++++---------- test/data_type/test_mx_fp4.cpp | 9 +++++---- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 26e4787949..e6ec4c897c 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -248,7 +248,7 @@ #define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1 // workaround: compiler issue on gfx950 -#define CK_TEMP_DISABLE_FP4_TESTS 1 +#define CK_TEMP_DISABLE_FP4_TESTS 0 // workaround: compiler issue on gfx950 #define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1 diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 51da18cd2b..f3ea37b07c 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -51,14 +51,14 @@ struct f4x2_pk_t { static_assert(I < 2, "Index is out of range."); if constexpr(I == 0) - return (data >> 4); - else return data & 0b00001111; + else + return (data >> 4); } __host__ __device__ inline type pack(const type x0, const type x1) { - return (x0 << 4) | (x1 & 0b00001111); + return (x1 << 4) | (x0 & 0b00001111); } }; diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index f3e2bd3dd9..a3995dd50a 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -380,7 +380,7 @@ inline __host__ __device__ float2_t scaled_type_convert(e8m0_b float2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); // permute high bits and low bits to match the order of the original vector - return float2_t{tmp[1], tmp[0]}; + return float2_t{tmp[0], tmp[1]}; #else 
float2_t ret{utils::to_float( scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{})), @@ -408,8 +408,8 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0); // permute high bits and low bits to match the order of the original vector - ret[2 * idx] = op[1]; - ret[2 * idx + 1] = op[0]; + ret[2 * idx] = op[0]; + ret[2 * idx + 1] = op[1]; }); return ret; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 5865f1dd78..0851d68e8b 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1388,8 +1388,10 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; + // If we keep the original order, an error occurred: + value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0); // permute high bits and low bits to match the order of the original vector - value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0); + // value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0); return value.f4x2_array[0]; #else union @@ -1397,8 +1399,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type(x[1] / scale); - uint8_t h = utils::sat_convert_to_type(x[0] / scale); + uint8_t l = utils::sat_convert_to_type(x[0] / scale); + uint8_t h = utils::sat_convert_to_type(x[1] / scale); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif @@ -1418,7 +1420,7 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { // permute high bits and low bits to match the order of the original vector 
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( - tmp_values.bitwise, x[2 * idx + 1], x[2 * idx], scale, 0); + tmp_values.bitwise, x[2 * idx], x[2 * idx + 1], scale, 0); f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0]; }); @@ -1489,13 +1491,13 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) } value{0}; // apply a temporary workaround for gfx950 #if CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION - uint8_t l = utils::sat_convert_to_type_sr(x[1] / scale, rng); - uint8_t h = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t l = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t h = utils::sat_convert_to_type_sr(x[1] / scale, rng); value.bitwise = (h << 4) | l; #else // permute high bits and low bits to match the order of the original vector value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0); + value.bitwise, float2_t{x[0], x[1]}, rng, scale, 0); #endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION return value.f4x2_array[0]; #else @@ -1504,8 +1506,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type_sr(x[1] / scale, rng); - uint8_t h = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t l = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t h = utils::sat_convert_to_type_sr(x[1] / scale, rng); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif @@ -1538,7 +1540,7 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f // permute high bits and low bits to match the order of the original vector f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( f4_values.bitwise, - float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]}, + float2_t{float_values.floatx2_array[idx][0], float_values.floatx2_array[idx][1]}, rng, scale, 0); 
diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 7aca42567c..13b1b3ccc8 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -86,16 +86,17 @@ test_mx_fp4_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) auto scale2 = e8m0_bexp_t(2.0f); float2_t f32x2 = scaled_type_convert(scale2, f4x2); - p_test[i++] = f32x2[0]; + p_test[i++] = f32x2[0]; // 2* 0b1100(=-2.0) = -4.0 if(i >= N) { return; } - p_test[i++] = f32x2[1]; + p_test[i++] = f32x2[1]; // 2* 0b0001(=0.5) = 1.0 if(i >= N) { return; } + // expected {-4, 1.0} // f32x2 -> f4x2 f32x2 = {1.0f, -4.0f}; @@ -212,8 +213,8 @@ TEST(MXFP4, HostScaledConvert) auto i = 256 * 16; // f4x2 -> f32x2 - EXPECT_EQ(out[i++], 1.0f); EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 1.0f); // f32x2 -> f4x2 // RNE @@ -297,8 +298,8 @@ TEST(MXFP4, DeviceScaledConvert) auto i = 256 * 16; // f4x2 -> f32x2 - EXPECT_EQ(out[i++], 1.0f); EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 1.0f); // f32x2 -> f4x2 // RNE