From de3cfbab9aa44951b6f32b0bec88e684029c87e0 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 23 Jun 2025 09:31:46 -0500 Subject: [PATCH] Add accelerated stochastic rounding on gfx950 (#2355) * Add native prand generation support for gfx950 * Update seed calculation [ROCm/composable_kernel commit: dbfe70e72a5f2f0317b715cd4c7f7fb662affbe5] --- include/ck/utility/amd_ck_fp8.hpp | 65 +++++++++++++--- include/ck/utility/mxf8_utils.hpp | 10 ++- include/ck/utility/type_convert.hpp | 114 ++++++++++++++++++---------- 3 files changed, 134 insertions(+), 55 deletions(-) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index d079639c6a..cdc2a4fbda 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" #include "ck/utility/enable_if.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/random_gen.hpp" #include "ck/utility/functional.hpp" #include "ck/utility/type.hpp" @@ -1396,12 +1397,18 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } return cast_to_f8_from_f32( f, rng); @@ -1416,12 +1423,18 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) @@ -1487,12 +1500,18 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f[0]); + rng = prand_generator(reinterpret_cast(&f), f[0]); #else rng = prand_generator(reinterpret_cast(&f), f[0]); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } return cast_to_f8_from_f32( f, rng); @@ -1532,12 +1551,18 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC rng = prand_generator(reinterpret_cast(&x), x); #else rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_f16(reinterpret_cast(&x), x[0]); #else rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_f16(reinterpret_cast(&x), static_cast(x)); #else rng = prand_generator(reinterpret_cast(&x), static_cast(x)); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_bf16(reinterpret_cast(&x), + rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); #else rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_bf16(reinterpret_cast(&f), f); + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); } return cast_to_f8_from_f32_scaled(f, rng, scale); } @@ -221,8 +222,9 @@ __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const uint32_t rng = 0; if constexpr(stochastic_rounding) { - constexpr int seed = 1254739; - rng = prand_generator(reinterpret_cast(&f), f[0]); + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); } return cast_to_f8_from_f32_scaled(f, rng, scale); } diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 5865f1dd78..2208a73860 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/f8_utils.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/mxf4_utils.hpp" #include "ck/utility/mxf6_utils.hpp" #include "ck/utility/random_gen.hpp" @@ -234,12 +235,18 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); template <> inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) #if defined(__gfx94__) union { @@ -296,12 +303,18 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(half_t x) template <> inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) #if defined(__gfx94__) union { @@ -1446,13 +1459,10 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 // convert fp32 to fp4 with stochastic rounding inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { uint32_t bitwise; @@ -1468,6 +1478,12 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) value.bitwise, float_values.float2_array, rng, scale, 0); return value.f4_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -1475,13 +1491,10 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) // convert vector of 2 fp32 to vector of 2 fp4 with sr inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { uint32_t bitwise; @@ -1499,6 +1512,12 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) #endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION return value.f4x2_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif union { uint32_t bitwise; @@ -1514,13 +1533,10 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) // convert vector of 32 fp32 to vector of 32 fp4 with sr inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { __uint128_t bitwise; @@ -1546,6 +1562,12 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f return f4_values.f4x32_array; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif union { __uint128_t bitwise; @@ -1776,13 +1798,10 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 */ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { float32_t float_vector; @@ -1799,6 +1818,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) return out.f6_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -1815,6 +1840,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) */ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); + return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); +#else constexpr int seed = 1254739; union { @@ -1828,9 +1859,6 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); #endif -#if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); -#else union { float32_t float_vector; @@ -2044,13 +2072,10 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1 */ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { float32_t float_vector; @@ -2067,6 +2092,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) return out.bf6_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -2085,6 +2116,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) */ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); + return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); +#else constexpr int seed = 1254739; union { @@ -2098,9 +2135,6 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1. uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); #endif -#if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); -#else union { float32_t float_vector;