mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add accelerated stochastic rounding on gfx950 (#2355)
* Add native prand generation support for gfx950
* Update seed calculation
[ROCm/composable_kernel commit: dbfe70e72a]
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
#include "ck/utility/functional.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
@@ -1396,12 +1397,18 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
f, rng);
|
||||
@@ -1416,12 +1423,18 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
|
||||
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
|
||||
@@ -1487,12 +1500,18 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
|
||||
f, rng);
|
||||
@@ -1532,12 +1551,18 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_f16<interp,
|
||||
@@ -1574,12 +1599,18 @@ __host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_f16<interp,
|
||||
@@ -1616,13 +1647,19 @@ __host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
static_cast<float>(x));
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_bf16<interp,
|
||||
@@ -1664,14 +1701,20 @@ __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
|
||||
static_cast<float>(x[0]));
|
||||
#else
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
|
||||
static_cast<float>(x[0]));
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
return cast_to_f8_from_bf16<interp,
|
||||
|
||||
@@ -197,8 +197,9 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const fl
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
}
|
||||
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
|
||||
}
|
||||
@@ -221,8 +222,9 @@ __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const
|
||||
uint32_t rng = 0;
|
||||
if constexpr(stochastic_rounding)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
}
|
||||
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/f8_utils.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/mxf4_utils.hpp"
|
||||
#include "ck/utility/mxf6_utils.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
@@ -234,12 +235,18 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
|
||||
template <>
|
||||
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -296,12 +303,18 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
|
||||
template <>
|
||||
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#endif // #ifndef CK_CODE_GEN_RTC
|
||||
#endif // #if defined(__gfx950__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
@@ -1446,13 +1459,10 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
|
||||
// convert fp32 to fp4 with stochastic rounding
|
||||
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
@@ -1468,6 +1478,12 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
value.bitwise, float_values.float2_array, rng, scale, 0);
|
||||
return value.f4_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
|
||||
#endif
|
||||
}
|
||||
@@ -1475,13 +1491,10 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
// convert vector of 2 fp32 to vector of 2 fp4 with sr
|
||||
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
@@ -1499,6 +1512,12 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
#endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
|
||||
return value.f4x2_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
@@ -1514,13 +1533,10 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
// convert vector of 32 fp32 to vector of 32 fp4 with sr
|
||||
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
__uint128_t bitwise;
|
||||
@@ -1546,6 +1562,12 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
union
|
||||
{
|
||||
__uint128_t bitwise;
|
||||
@@ -1776,13 +1798,10 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
|
||||
*/
|
||||
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
@@ -1799,6 +1818,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
|
||||
return out.f6_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
|
||||
#endif
|
||||
}
|
||||
@@ -1815,6 +1840,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
*/
|
||||
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
union
|
||||
{
|
||||
@@ -1828,9 +1859,6 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
|
||||
#else
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
@@ -2044,13 +2072,10 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
|
||||
*/
|
||||
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
@@ -2067,6 +2092,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
|
||||
return out.bf6_array[0];
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
|
||||
#endif
|
||||
}
|
||||
@@ -2085,6 +2116,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
*/
|
||||
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
union
|
||||
{
|
||||
@@ -2098,9 +2135,6 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
|
||||
#else
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
|
||||
Reference in New Issue
Block a user