Update packed fp4 layout (#2523)

This commit is contained in:
Rostyslav Geyyer
2025-07-21 16:58:59 -05:00
committed by GitHub
parent 1fa1c34b7e
commit c9886109b4
4 changed files with 20 additions and 36 deletions

View File

@@ -50,7 +50,7 @@ struct f4x2_pk_t
__host__ __device__ inline type unpack(Number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0)
if constexpr(I == 1)
return (data >> 4);
else
return data & 0b00001111;
@@ -58,7 +58,7 @@ struct f4x2_pk_t
__host__ __device__ inline type pack(const type x0, const type x1)
{
return (x0 << 4) | (x1 & 0b00001111);
return (x1 << 4) | (x0 & 0b00001111);
}
// Compare operator

View File

@@ -377,10 +377,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
f4x2_t f4x2_array[4];
} value{};
value.f4x2_array[0] = x;
float2_t tmp =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
// permute high bits and low bits to match the order of the original vector
return float2_t{tmp[1], tmp[0]};
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
float2_t ret{utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
@@ -406,10 +403,9 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
float f_scale = type_convert<float>(scale);
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0);
// permute high bits and low bits to match the order of the original vector
ret[2 * idx] = op[1];
ret[2 * idx + 1] = op[0];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0);
ret[2 * idx] = op[0];
ret[2 * idx + 1] = op[1];
});
return ret;

View File

@@ -1401,8 +1401,7 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
// permute high bits and low bits to match the order of the original vector
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0);
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
return value.f4x2_array[0];
#else
union
@@ -1410,8 +1409,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type<f4_t>(x[1] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[0] / scale);
uint8_t l = utils::sat_convert_to_type<f4_t>(x[0] / scale);
uint8_t h = utils::sat_convert_to_type<f4_t>(x[1] / scale);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
@@ -1429,9 +1428,8 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
} f4_values{}, tmp_values{};
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
// permute high bits and low bits to match the order of the original vector
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
tmp_values.bitwise, x[2 * idx + 1], x[2 * idx], scale, 0);
tmp_values.bitwise, x[2 * idx], x[2 * idx + 1], scale, 0);
f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0];
});
@@ -1500,9 +1498,7 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
// permute high bits and low bits to match the order of the original vector
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
return value.f4x2_array[0];
#else
constexpr int seed = 1254739;
@@ -1516,8 +1512,8 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
uint32_t bitwise;
f4x2_t f4x2_array[4];
} value{0};
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
value.bitwise = (h << 4) | l;
return value.f4x2_array[0];
#endif
@@ -1544,13 +1540,8 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
float_values.floatx32_array = x;
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
// permute high bits and low bits to match the order of the original vector
f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
f4_values.bitwise,
float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]},
rng,
scale,
0);
f4_values.bitwise, float_values.floatx2_array[idx], rng, scale, 0);
});
return f4_values.f4x32_array;
@@ -1648,9 +1639,7 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
} value{};
value.f4x2_array[0] = x;
float scale = 1.0f;
float2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
// permute high bits and low bits to match the order of the original vector
return float2_t{tmp[1], tmp[0]};
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
float2_t ret{
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
@@ -1676,10 +1665,9 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
float scale = 1.0f;
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0);
// permute high bits and low bits to match the order of the original vector
ret[2 * idx] = op[1];
ret[2 * idx + 1] = op[0];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0);
ret[2 * idx] = op[0];
ret[2 * idx + 1] = op[1];
});
return ret;