mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix UB caused by reinterpret_cast (#2849)
* Use bit_cast instead of reinterpret_cast to avoid UB
* Apply same fix in ck_tile
[ROCm/composable_kernel commit: 14bbc545ea]
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include <ck/utility/ignore.hpp>
|
||||
#include <ck/utility/type.hpp>
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#ifdef CK_CODE_GEN_RTC
|
||||
@@ -17,7 +18,7 @@ namespace ck {
|
||||
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
|
||||
uint32_t x = bit_cast<uint32_t>(val);
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits ^= x >> 16;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
@@ -33,7 +34,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
|
||||
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
||||
uint16_t x = bit_cast<uint16_t>(val);
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
|
||||
@@ -24,7 +24,7 @@ struct prand_generator_t<float, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
|
||||
{
|
||||
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
|
||||
uint32_t x = bit_cast<uint32_t>(val);
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits ^= x >> 16;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
@@ -43,7 +43,7 @@ struct prand_generator_t<half_t, seed_>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
|
||||
{
|
||||
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
||||
uint16_t x = bit_cast<uint16_t>(val);
|
||||
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
||||
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
||||
drop_bits *= 0x7000149;
|
||||
|
||||
Reference in New Issue
Block a user