Simplify conversion function

This commit is contained in:
Rostyslav Geyyer
2025-02-19 20:39:02 +00:00
parent 50c1291317
commit bb953dad7e

View File

@@ -403,115 +403,72 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
f4x2_t fp4x2[16];
} value{};
value.f4x32_array = x;
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} bitwise_value{};
float2_t op;
float32_t ret;
// TODO: pack in a loop
// bitwise_value.f4x2_array[0] = value.fp4x2[0];
// op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
// bitwise_value.bitwise, type_convert<float>(scale), 0);
// ret[0] = op[0];
// ret[1] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[0], type_convert<float>(scale), 0);
ret[0] = op[0];
ret[1] = op[1];
ret[0] = op[1];
ret[1] = op[0];
// bitwise_value.f4x2_array[0] = value.fp4x2[1];
// op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
// bitwise_value.bitwise, type_convert<float>(scale), 0);
// ret[2] = op[0];
// ret[3] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[1], type_convert<float>(scale), 0);
ret[2] = op[0];
ret[3] = op[1];
ret[2] = op[1];
ret[3] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[2];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[4] = op[0];
ret[5] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[2], type_convert<float>(scale), 0);
ret[4] = op[1];
ret[5] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[3];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[6] = op[0];
ret[7] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[3], type_convert<float>(scale), 0);
ret[6] = op[1];
ret[7] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[4];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[8] = op[0];
ret[9] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[4], type_convert<float>(scale), 0);
ret[8] = op[1];
ret[9] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[5];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[10] = op[0];
ret[11] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[5], type_convert<float>(scale), 0);
ret[10] = op[1];
ret[11] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[6];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[12] = op[0];
ret[13] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[6], type_convert<float>(scale), 0);
ret[12] = op[1];
ret[13] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[7];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[14] = op[0];
ret[15] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[7], type_convert<float>(scale), 0);
ret[14] = op[1];
ret[15] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[8];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[16] = op[0];
ret[17] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[8], type_convert<float>(scale), 0);
ret[16] = op[1];
ret[17] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[9];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[18] = op[0];
ret[19] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[9], type_convert<float>(scale), 0);
ret[18] = op[1];
ret[19] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[10];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[20] = op[0];
ret[21] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[10], type_convert<float>(scale), 0);
ret[20] = op[1];
ret[21] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[11];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[22] = op[0];
ret[23] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[11], type_convert<float>(scale), 0);
ret[22] = op[1];
ret[23] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[12];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[24] = op[0];
ret[25] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[12], type_convert<float>(scale), 0);
ret[24] = op[1];
ret[25] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[13];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[26] = op[0];
ret[27] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[13], type_convert<float>(scale), 0);
ret[26] = op[1];
ret[27] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[14];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[28] = op[0];
ret[29] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[14], type_convert<float>(scale), 0);
ret[28] = op[1];
ret[29] = op[0];
bitwise_value.f4x2_array[0] = value.fp4x2[15];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[30] = op[0];
ret[31] = op[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[15], type_convert<float>(scale), 0);
ret[30] = op[1];
ret[31] = op[0];
return ret;
#else