From bb953dad7eb9a883c496eea33c361e1fb89399c2 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Wed, 19 Feb 2025 20:39:02 +0000 Subject: [PATCH] Simplify conversion function --- include/ck/utility/scaled_type_convert.hpp | 135 +++++++-------------- 1 file changed, 46 insertions(+), 89 deletions(-) diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 0a4ee4b28d..b810c881b2 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -403,115 +403,72 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m f4x2_t fp4x2[16]; } value{}; value.f4x32_array = x; - union - { - uint32_t bitwise; - f4x2_t f4x2_array[4]; - } bitwise_value{}; float2_t op; float32_t ret; // TODO: pack in a loop - // bitwise_value.f4x2_array[0] = value.fp4x2[0]; - // op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - // bitwise_value.bitwise, type_convert(scale), 0); - // ret[0] = op[0]; - // ret[1] = op[1]; op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[0], type_convert(scale), 0); - ret[0] = op[0]; - ret[1] = op[1]; + ret[0] = op[1]; + ret[1] = op[0]; - // bitwise_value.f4x2_array[0] = value.fp4x2[1]; - // op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - // bitwise_value.bitwise, type_convert(scale), 0); - // ret[2] = op[0]; - // ret[3] = op[1]; op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[1], type_convert(scale), 0); - ret[2] = op[0]; - ret[3] = op[1]; + ret[2] = op[1]; + ret[3] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[2]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[4] = op[0]; - ret[5] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[2], type_convert(scale), 0); + ret[4] = op[1]; + ret[5] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[3]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[6] = op[0]; - ret[7] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[3], type_convert(scale), 0); + ret[6] = op[1]; + ret[7] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[4]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[8] = op[0]; - ret[9] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[4], type_convert(scale), 0); + ret[8] = op[1]; + ret[9] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[5]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[10] = op[0]; - ret[11] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[5], type_convert(scale), 0); + ret[10] = op[1]; + ret[11] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[6]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[12] = op[0]; - ret[13] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[6], type_convert(scale), 0); + ret[12] = op[1]; + ret[13] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[7]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[14] = op[0]; - ret[15] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[7], type_convert(scale), 0); + ret[14] = op[1]; + ret[15] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[8]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[16] = op[0]; - ret[17] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[8], type_convert(scale), 0); + ret[16] = op[1]; + ret[17] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[9]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[18] = op[0]; - ret[19] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[9], type_convert(scale), 0); + ret[18] = op[1]; + ret[19] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[10]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[20] = op[0]; - ret[21] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[10], type_convert(scale), 0); + ret[20] = op[1]; + ret[21] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[11]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[22] = op[0]; - ret[23] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[11], type_convert(scale), 0); + ret[22] = op[1]; + ret[23] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[12]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[24] = op[0]; - ret[25] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[12], type_convert(scale), 0); + ret[24] = op[1]; + ret[25] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[13]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[26] = op[0]; - ret[27] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[13], type_convert(scale), 0); + ret[26] = op[1]; + ret[27] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[14]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[28] = op[0]; - ret[29] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[14], type_convert(scale), 0); + ret[28] = op[1]; + ret[29] = op[0]; - bitwise_value.f4x2_array[0] = value.fp4x2[15]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[30] = op[0]; - ret[31] = op[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[15], type_convert(scale), 0); + ret[30] = op[1]; + ret[31] = op[0]; return ret; #else