fix f4x2 implementation to match linear memory order

This commit is contained in:
aska-0096
2025-06-13 07:37:29 +00:00
parent d0757cccca
commit d57de07d68
5 changed files with 24 additions and 21 deletions

View File

@@ -248,7 +248,7 @@
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
// workaround: compiler issue on gfx950
#define CK_TEMP_DISABLE_FP4_TESTS 1
#define CK_TEMP_DISABLE_FP4_TESTS 0
// workaround: compiler issue on gfx950
#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1

View File

@@ -51,14 +51,14 @@ struct f4x2_pk_t
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0)
return (data >> 4);
else
return data & 0b00001111;
else
return (data >> 4);
}
__host__ __device__ inline type pack(const type x0, const type x1)
{
return (x0 << 4) | (x1 & 0b00001111);
return (x1 << 4) | (x0 & 0b00001111);
}
};

View File

@@ -380,7 +380,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
float2_t tmp =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
// permute high bits and low bits to match the order of the original vector
return float2_t{tmp[1], tmp[0]};
return float2_t{tmp[0], tmp[1]};
#else
float2_t ret{utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
@@ -408,8 +408,8 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0);
// permute high bits and low bits to match the order of the original vector
ret[2 * idx] = op[1];
ret[2 * idx + 1] = op[0];
ret[2 * idx] = op[0];
ret[2 * idx + 1] = op[1];
});
return ret;

View File

@@ -1388,8 +1388,10 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
// If we keep origin order, error occured:
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
// permute high bits and low bits to match the order of the original vector
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0);
// value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0);
return value.f4x2_array[0];
#else
union
@@ -1397,8 +1399,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
uint8_t l = utils::sat_convert_to_type<f4_t>(x[0] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[1] / scale);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
@@ -1418,7 +1420,7 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
// permute high bits and low bits to match the order of the original vector
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
tmp_values.bitwise, x[2 * idx + 1], x[2 * idx], scale, 0);
tmp_values.bitwise, x[2 * idx], x[2 * idx + 1], scale, 0);
f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0];
});
@@ -1489,13 +1491,13 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
} value{0};
// apply a temporary workaround for gfx950
#if CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
value.bitwise = (h << 4) | l;
#else
// permute high bits and low bits to match the order of the original vector
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
value.bitwise, float2_t{x[0], x[1]}, rng, scale, 0);
#endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
return value.f4x2_array[0];
#else
@@ -1504,8 +1506,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
@@ -1538,7 +1540,7 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
// permute high bits and low bits to match the order of the original vector
f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
f4_values.bitwise,
float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]},
float2_t{float_values.floatx2_array[idx][0], float_values.floatx2_array[idx][1]},
rng,
scale,
0);

View File

@@ -86,16 +86,17 @@ test_mx_fp4_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
auto scale2 = e8m0_bexp_t(2.0f);
float2_t f32x2 = scaled_type_convert<float2_t>(scale2, f4x2);
p_test[i++] = f32x2[0];
p_test[i++] = f32x2[0]; // 2* 0b1100(=-2.0) = -4.0
if(i >= N)
{
return;
}
p_test[i++] = f32x2[1];
p_test[i++] = f32x2[1]; // 2* 0b0001(=0.5) = 1.0
if(i >= N)
{
return;
}
// expected {-4, 1.0}
// f32x2 -> f4x2
f32x2 = {1.0f, -4.0f};
@@ -212,8 +213,8 @@ TEST(MXFP4, HostScaledConvert)
auto i = 256 * 16;
// f4x2 -> f32x2
EXPECT_EQ(out[i++], 1.0f);
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 1.0f);
// f32x2 -> f4x2
// RNE
@@ -297,8 +298,8 @@ TEST(MXFP4, DeviceScaledConvert)
auto i = 256 * 16;
// f4x2 -> f32x2
EXPECT_EQ(out[i++], 1.0f);
EXPECT_EQ(out[i++], -4.0f);
EXPECT_EQ(out[i++], 1.0f);
// f32x2 -> f4x2
// RNE