From c58afb0565775e8a0fee5fdcd6b14c4b04f4658d Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 21 Jul 2025 16:58:59 -0500 Subject: [PATCH] Update packed fp4 layout (#2523) [ROCm/composable_kernel commit: c9886109b43fdd73679c4443b6616a83eb40e066] --- include/ck/utility/data_type.hpp | 4 +-- include/ck/utility/scaled_type_convert.hpp | 12 +++----- include/ck/utility/type_convert.hpp | 36 ++++++++-------------- test/data_type/test_mx_fp4.cpp | 4 +-- 4 files changed, 20 insertions(+), 36 deletions(-) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8f5a45bdf0..5fbe30d21b 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -50,7 +50,7 @@ struct f4x2_pk_t __host__ __device__ inline type unpack(Number) const { static_assert(I < 2, "Index is out of range."); - if constexpr(I == 0) + if constexpr(I == 1) return (data >> 4); else return data & 0b00001111; @@ -58,7 +58,7 @@ struct f4x2_pk_t __host__ __device__ inline type pack(const type x0, const type x1) { - return (x0 << 4) | (x1 & 0b00001111); + return (x1 << 4) | (x0 & 0b00001111); } // Compare operator diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 90a018fe3a..7de84d974c 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -377,10 +377,7 @@ inline __host__ __device__ float2_t scaled_type_convert(e8m0_b f4x2_t f4x2_array[4]; } value{}; value.f4x2_array[0] = x; - float2_t tmp = - __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); - // permute high bits and low bits to match the order of the original vector - return float2_t{tmp[1], tmp[0]}; + return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); #else float2_t ret{utils::to_float( scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{})), @@ -406,10 +403,9 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m float f_scale = type_convert(scale); ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0); - // permute high bits and low bits to match the order of the original vector - ret[2 * idx] = op[1]; - ret[2 * idx + 1] = op[0]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0); + ret[2 * idx] = op[0]; + ret[2 * idx + 1] = op[1]; }); return ret; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 05e461fa63..c859cfba3d 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1401,8 +1401,7 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - // permute high bits and low bits to match the order of the original vector - value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0); + value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0); return value.f4x2_array[0]; #else union @@ -1410,8 +1409,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type(x[1] / scale); - uint8_t h = utils::sat_convert_to_type(x[0] / scale); + uint8_t l = utils::sat_convert_to_type(x[0] / scale); + uint8_t h = utils::sat_convert_to_type(x[1] / scale); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif @@ -1429,9 +1428,8 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 } f4_values{}, tmp_values{}; ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - // permute high bits and low bits to match the order of the original vector tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( - tmp_values.bitwise, x[2 * idx + 1], x[2 * idx], scale, 0); + tmp_values.bitwise, x[2 * idx], x[2 * idx + 1], scale, 0); f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0]; }); @@ -1500,9 +1498,7 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - // permute high bits and low bits to match the order of the original vector - value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0); + value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0); return value.f4x2_array[0]; #else constexpr int seed = 1254739; @@ -1516,8 +1512,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - uint8_t l = utils::sat_convert_to_type_sr(x[1] / scale, rng); - uint8_t h = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t l = utils::sat_convert_to_type_sr(x[0] / scale, rng); + uint8_t h = utils::sat_convert_to_type_sr(x[1] / scale, rng); value.bitwise = (h << 4) | l; return value.f4x2_array[0]; #endif @@ -1544,13 +1540,8 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f float_values.floatx32_array = x; ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - // permute high bits and low bits to match the order of the original vector f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - f4_values.bitwise, - float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]}, - rng, - scale, - 0); + f4_values.bitwise, float_values.floatx2_array[idx], rng, scale, 0); }); return f4_values.f4x32_array; @@ -1648,9 +1639,7 @@ inline __host__ __device__ float2_t type_convert(f4x2_t x) } value{}; value.f4x2_array[0] = x; float scale = 1.0f; - float2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); - // permute high bits and low bits to match the order of the original vector - return float2_t{tmp[1], tmp[0]}; + return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); #else float2_t ret{ utils::to_float(NumericLimits::Binary_1(), @@ -1676,10 +1665,9 @@ inline __host__ __device__ float32_t type_convert(f4x32_t x) float scale = 1.0f; ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0); - // permute high bits and low bits to match the order of the original vector - ret[2 * idx] = op[1]; - ret[2 * idx + 1] = op[0]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0); + ret[2 * idx] = op[0]; + ret[2 * idx + 1] = op[1]; }); return ret; diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 449f6fc777..c8059fa097 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -212,8 +212,8 @@ TEST(MXFP4, HostScaledConvert) auto i = 256 * 16; // f4x2 -> f32x2 - EXPECT_EQ(out[i++], 1.0f); EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 1.0f); // f32x2 -> f4x2 // RNE @@ -296,8 +296,8 @@ TEST(MXFP4, DeviceScaledConvert) auto i = 256 * 16; // f4x2 -> f32x2 - EXPECT_EQ(out[i++], 1.0f); EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 1.0f); // f32x2 -> f4x2 // RNE