Fix UB caused by reinterpret_cast (#2849)

* Use bit_cast instead of reinterpret_cast to avoid UB

* Apply same fix in ck_tile

[ROCm/composable_kernel commit: 14bbc545ea]
This commit is contained in:
Rostyslav Geyyer
2025-09-18 09:12:37 -05:00
committed by GitHub
parent d5ff6cb785
commit 5ee7f320a0
2 changed files with 5 additions and 4 deletions

View File

@@ -3,6 +3,7 @@
#pragma once
#include <ck/utility/ignore.hpp>
#include <ck/utility/type.hpp>
#include "ck/ck.hpp"
#ifdef CK_CODE_GEN_RTC
@@ -17,7 +18,7 @@ namespace ck {
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
-uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
+uint32_t x = bit_cast<uint32_t>(val);
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
@@ -33,7 +34,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
-uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
+uint16_t x = bit_cast<uint16_t>(val);
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;

View File

@@ -24,7 +24,7 @@ struct prand_generator_t<float, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
{
-uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
+uint32_t x = bit_cast<uint32_t>(val);
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
@@ -43,7 +43,7 @@ struct prand_generator_t<half_t, seed_>
{
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
{
-uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
+uint16_t x = bit_cast<uint16_t>(val);
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;