From 5f8f190b34a3e14feb87f749d1a0a0fa8f5b62c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Manuel=20Martinez=20Caama=C3=B1o?= Date: Wed, 7 Aug 2024 20:49:02 +0200 Subject: [PATCH] Remove reinterpret_cast uses that result in undefined behaviour. (#1445) * Remove reinterpret_cast uses that result in undefined behaviour. Use a bitcast instead. See https://en.cppreference.com/w/cpp/language/reinterpret_cast#Type_accessibility Closes #1439 * fix clang format --------- Co-authored-by: illsilin [ROCm/composable_kernel commit: 901e5f1540371e475aa9472b9105bb0152a58880] --- include/ck/utility/f8_utils.hpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 98e8092af5..2533073225 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -44,7 +44,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) // convert to bitwise using T_bitwise = typename NumericUtils::bitwise_type; - T_bitwise x_bitwise = *(reinterpret_cast(&x)); + T_bitwise x_bitwise = bit_cast(x); // unpack the input, depends on datatype head = x_bitwise & NumericUtils::head_mask; @@ -196,18 +196,17 @@ __host__ __device__ Y run_cast_from_f8(X x) // prepare the codes constexpr X nan_code = 0x80; - Y Inf, NegInf, NaN, Neg0; - using T_bitwise = typename NumericUtils::bitwise_type; + using T_bitwise = typename NumericUtils::bitwise_type; constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; constexpr T_bitwise NaN_bitwise = NumericUtils::NaN; constexpr T_bitwise Neg0_bitwise = NumericUtils::Neg0; - Inf = *(reinterpret_cast(&Inf_bitwise)); - NegInf = *(reinterpret_cast(&NegInf_bitwise)); - NaN = *(reinterpret_cast(&NaN_bitwise)); - Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + constexpr Y Inf = bit_cast(Inf_bitwise); + constexpr Y NegInf = bit_cast(NegInf_bitwise); + constexpr Y NaN = bit_cast(NaN_bitwise); + constexpr Y Neg0 = bit_cast(Neg0_bitwise); // check if x is 0.0 if(x == 0) @@ -240,7 +239,7 @@ __host__ __device__ Y run_cast_from_f8(X x) { retval = x; retval <<= 8; - return *(reinterpret_cast(&retval)); + return bit_cast(retval); } // subnormal input @@ -264,7 +263,7 @@ __host__ __device__ Y run_cast_from_f8(X x) } retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; - return *(reinterpret_cast(&retval)); + return bit_cast(retval); } } // namespace