From 83e2403545fe968cb01d53f0cb3b511efc841cbb Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 18 Sep 2025 09:12:37 -0500 Subject: [PATCH] Fix UB caused by reinterpret_cast (#2849) * Use bit_cast instead of reinterpret_cast to avoid UB * Apply same fix in ck_tile [ROCm/composable_kernel commit: 14bbc545ea672e66cdce00a3edbf4c532e2657e8] --- include/ck/utility/random_gen.hpp | 5 +++-- include/ck_tile/core/utility/random.hpp | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp index c37d3922ca..2ff46457fc 100644 --- a/include/ck/utility/random_gen.hpp +++ b/include/ck/utility/random_gen.hpp @@ -3,6 +3,7 @@ #pragma once #include +#include #include "ck/ck.hpp" #ifdef CK_CODE_GEN_RTC @@ -17,7 +18,7 @@ namespace ck { template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { - uint32_t x = *(reinterpret_cast<uint32_t*>(&val)); + uint32_t x = bit_cast<uint32_t>(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits ^= x >> 16; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); @@ -33,7 +34,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { - uint16_t x = *(reinterpret_cast<uint16_t*>(&val)); + uint16_t x = bit_cast<uint16_t>(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); drop_bits *= 0x7000149; diff --git a/include/ck_tile/core/utility/random.hpp b/include/ck_tile/core/utility/random.hpp index f7fbfad4dd..6a38ad3bde 100644 --- a/include/ck_tile/core/utility/random.hpp +++ b/include/ck_tile/core/utility/random.hpp @@ -24,7 +24,7 @@ struct prand_generator_t { CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_) { - uint32_t x = *(reinterpret_cast<uint32_t*>(&val)); + uint32_t x = 
bit_cast<uint32_t>(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits ^= x >> 16; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); @@ -43,7 +43,7 @@ struct prand_generator_t { CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_) { - uint16_t x = *(reinterpret_cast<uint16_t*>(&val)); + uint16_t x = bit_cast<uint16_t>(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); drop_bits *= 0x7000149;