diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index ab65372fd5..099debc3d4 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -401,8 +401,7 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m { f4x32_t f4x32_array; f4x2_t fp4x2[16]; - } value{}; - value.f4x32_array = x; + } value{x}; float2_t op; float32_t ret; float f_scale = type_convert(scale); diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index b1f6f0d15c..5f94a56636 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1002,110 +1002,16 @@ inline __host__ __device__ float32_t type_convert(f4x32_t x) f4x32_t f4x32_array; f4x2_t fp4x2[16]; } value{x}; - union - { - uint32_t bitwise; - f4x2_t f4x2_array[4]; - } bitwise_value{}; float2_t op; float32_t ret; float scale = 1.0f; - // TODO: pack in a loop - bitwise_value.f4x2_array[0] = value.fp4x2[0]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[0] = op[0]; - ret[1] = op[1]; - bitwise_value.f4x2_array[0] = value.fp4x2[1]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[2] = op[0]; - ret[3] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[2]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[4] = op[0]; - ret[5] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[3]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[6] = op[0]; - ret[7] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[4]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[8] = op[0]; - ret[9] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[5]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[10] = op[0]; - ret[11] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[6]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[12] = op[0]; - ret[13] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[7]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[14] = op[0]; - ret[15] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[8]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[16] = op[0]; - ret[17] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[9]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[18] = op[0]; - ret[19] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[10]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[20] = op[0]; - ret[21] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[11]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[22] = op[0]; - ret[23] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[12]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[24] = op[0]; - ret[25] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[13]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[26] = op[0]; - ret[27] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[14]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[28] = op[0]; - ret[29] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[15]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[30] = op[0]; - ret[31] = op[1]; + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0); + // permute high bits and low bits to match the order of the original vector + ret[2 * idx] = op[1]; + ret[2 * idx + 1] = op[0]; + }); return ret; #else @@ -1120,106 +1026,18 @@ inline __host__ __device__ float32_t type_convert(f4x32_t x) f4x2_t f4x2_array[16]; f4x32_t f4x32_array; } f4_values{bit_cast<__uint128_t>(x)}; - // TODO: pack in a loop - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[3].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[3].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + float_values.float_array[2 * idx] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[idx].template AsType()[Number<0>{}].template unpack<>( + Number<0>{})); - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[14].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[14].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[2 * idx + 1] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[idx].template AsType()[Number<0>{}].template unpack<>( + Number<1>{})); + }); return float_values.float32_array; #endif