Add accelerated stochastic rounding on gfx950 (#2355)

* Add native prand generation support for gfx950

* Update seed calculation

[ROCm/composable_kernel commit: dbfe70e72a]
This commit is contained in:
Rostyslav Geyyer
2025-06-23 09:31:46 -05:00
committed by GitHub
parent 7c57c4f045
commit de3cfbab9a
3 changed files with 134 additions and 55 deletions

View File

@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/type.hpp"
@@ -1396,12 +1397,18 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng);
@@ -1416,12 +1423,18 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
@@ -1487,12 +1500,18 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f[0]);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng);
@@ -1532,12 +1551,18 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_f16<interp,
@@ -1574,12 +1599,18 @@ __host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_f16<interp,
@@ -1616,13 +1647,19 @@ __host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
static_cast<float>(x));
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), static_cast<float>(x));
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_bf16<interp,
@@ -1664,14 +1701,20 @@ __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x)
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
static_cast<float>(x[0]));
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x),
static_cast<float>(x[0]));
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
}
#if defined(__gfx950__)
return cast_to_f8_from_bf16<interp,

View File

@@ -197,8 +197,9 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const fl
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
@@ -221,8 +222,9 @@ __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
// use HW clock for stochastic input multiply by incremented thread id
rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

View File

@@ -5,6 +5,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp"
@@ -234,12 +235,18 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <>
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
#if defined(__gfx94__)
union
{
@@ -296,12 +303,18 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
template <>
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#endif // #ifndef CK_CODE_GEN_RTC
#endif // #if defined(__gfx950__)
#if defined(__gfx94__)
union
{
@@ -1446,13 +1459,10 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
uint32_t bitwise;
@@ -1468,6 +1478,12 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
value.bitwise, float_values.float2_array, rng, scale, 0);
return value.f4_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
@@ -1475,13 +1491,10 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
uint32_t bitwise;
@@ -1499,6 +1512,12 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
#endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
return value.f4x2_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
union
{
uint32_t bitwise;
@@ -1514,13 +1533,10 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
__uint128_t bitwise;
@@ -1546,6 +1562,12 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
return f4_values.f4x32_array;
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
#endif
union
{
__uint128_t bitwise;
@@ -1776,13 +1798,10 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
*/
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
float32_t float_vector;
@@ -1799,6 +1818,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
return out.f6_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
@@ -1815,6 +1840,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
constexpr int seed = 1254739;
union
{
@@ -1828,9 +1859,6 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
#endif
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;
@@ -2044,13 +2072,10 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
*/
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
union
{
float32_t float_vector;
@@ -2067,6 +2092,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
return out.bf6_array[0];
#else
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
@@ -2085,6 +2116,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
constexpr int seed = 1254739;
union
{
@@ -2098,9 +2135,6 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
#endif
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
union
{
float32_t float_vector;