From 7fbc128e832daa491913bbafd0c43d95e24ed1d5 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 3 Apr 2025 12:42:03 -0500 Subject: [PATCH] Add FP16/BF16<->FP8/BF8 conversions (#2035) * Move conversion functions and add missing conversions * Add tests * Add missing conversions * Add missing conversions * Add bf8 tests * Update clipping for vectors * Add missing conversions * Add bf16 fp8 tests * Add bf16 bf8 tests * Fix device conversion * Fix conversions * Fix vector use * Minor fix * Add a workaround flag * Add a workaround flag for bf16 conversion * Add another workaround * Add a workaround for fp16 to bf8 conversion * Update type alias * Add docstrings and missing wrappers * Fix if defined macros * Fix more if defined macros * Add comments * Remove __host__ specifier * Add a gfx950 guard * Update function naming [ROCm/composable_kernel commit: 265af71a71fd81c99988365477973c337c512e13] --- include/ck/ck.hpp | 6 + include/ck/utility/amd_ck_fp8.hpp | 864 +++++++++++++++++++-- include/ck/utility/mxf8_utils.hpp | 2 +- include/ck/utility/scaled_type_convert.hpp | 4 +- include/ck/utility/type_convert.hpp | 696 ++++++++++++++++- test/data_type/test_bf8_ocp.cpp | 595 +++++++++++++- test/data_type/test_fp8_ocp.cpp | 571 +++++++++++++- 7 files changed, 2628 insertions(+), 110 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 5fa73d2fda..1d49b68a32 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -248,6 +248,12 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // workaround: compiler issue on gfx950 #define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1 +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1 + +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_BF16_TO_FP8_CONVERSION 1 + // denorm test fix, necessary for gfx90a #ifndef CK_GFX90A_DENORM_WORKAROUND #define CK_GFX90A_DENORM_WORKAROUND 0 diff --git a/include/ck/utility/amd_ck_fp8.hpp 
b/include/ck/utility/amd_ck_fp8.hpp index 5c80c42d6c..b0089bb2d1 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -64,6 +64,9 @@ enum class ck_saturation_t namespace fp8_impl { typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2))); +typedef _Float16 half2_t __attribute__((ext_vector_type(2))); +typedef ushort ushortx2_t __attribute__((ext_vector_type(2))); +typedef short shortx2_t __attribute__((ext_vector_type(2))); typedef float float2_t __attribute__((ext_vector_type(2))); __host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a) @@ -270,7 +273,7 @@ static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v) } template -static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) +static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v) { const auto i16val = bit_cast(v); @@ -458,6 +461,510 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp) #endif } +#if defined(__gfx950__) +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ 
fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + std::ignore = rng; + + union + { + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + 
{ + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.half_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + std::ignore = rng; + + union + { + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> 
+static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + 
ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +#else + std::ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned 
int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +} +#endif // defined(__gfx950__) + #if CK_FP8_CVT_FAST_PATH // The conversion function is from rocblas // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 @@ -523,6 +1030,84 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = } return i8data; } + +template +static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0) +{ + if 
constexpr(stochastic_rounding) + { + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f32(v[0], rng), + cast_to_f8_from_f32(v[1], rng)}; + } + else + { + union + { + float fval; + unsigned int i32val; + unsigned char i8val[4]; + } val0, val1; + + val0.fval = v[0]; + val1.fval = v[1]; + + unsigned int ival = 0; + + if constexpr(saturate) + { + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { // OCP type + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0); + } + } + else + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0); + } + } + } + + // RNE CVT + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false); + } + else + { + ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false); + } + + val0.i32val = ival; + + return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]}; + } +} #endif // CK_FP8_CVT_FAST_PATH // The conversion 
function is from rocblas @@ -797,6 +1382,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn * * \tparam interp interpretation of fp8 * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR * \param f float number * \return fp8_storage_t */ @@ -882,6 +1468,47 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #endif // CK_FP8_CVT_FAST_PATH } +/** + * \brief convert vector of 2 floats to vector of 2 @p fp8_storage_t + * + * \tparam interp interpretation of fp8 + * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param f vector of 2 floats + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH +__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f[0]); +#else + rng = prand_generator(reinterpret_cast(&f), f[0]); +#endif + } + return cast_to_f8_from_f32( + f, rng); +#else +#if CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#else +__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#endif // CK_USE_OCP_FP8 + return fp8x2_storage_t{cvt_float_to_fp8(f[0]), + cvt_float_to_fp8(f[1])}; +#endif // CK_FP8_CVT_FAST_PATH +} + /** * \brief convert _Float16 to @p fp8_storage_t * @@ -900,87 +1527,168 @@ __host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) #endif { - return cvt_float_to_fp8(static_cast(x)); + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x); +#else + rng = 
prand_generator(reinterpret_cast(&x), x); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + static_cast(x)); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 _Float16 to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 _Float16 + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + float2_t{static_cast(x[0]), static_cast(x[1])}); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert bhalf_t to @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x bhalf_t value + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#else +__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x)); +#else + rng = prand_generator(reinterpret_cast(&x), 
static_cast(x)); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + bit_cast(uint32_t{x} << 16)); // convert value to float +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 bhalf_t to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 bhalf_t + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#endif +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#else + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#endif // defined(__gfx950__) + } +#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION } } // namespace fp8_impl -// Declare a template function for fp8 conversion using RNE -template -__host__ __device__ constexpr Y f8_convert_rne(X x); - -// convert fp32 to fp8 with rounding to nearest even -template <> -inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) -{ - return f8_ocp_t{ - fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert fp32 to bf8 with rounding to nearest 
even -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) -{ - return bf8_ocp_t{ - fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert _Float16 to fp8 with rounding to nearest even -template <> -inline __host__ __device__ f8_ocp_t f8_convert_rne(_Float16 x) -{ - return f8_ocp_t{ - fp8_impl::cvt_half_t_to_fp8(x)}; -} - -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_rne(_Float16 x) -{ - return bf8_ocp_t{ - fp8_impl::cvt_half_t_to_fp8( - x)}; -} - -// Declare a template function for fp8 conversion using RNE -template -__host__ __device__ constexpr Y f8_convert_sr(X x); - -// convert fp32 to fp8 with stochastic rounding -template <> -inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) -{ - return f8_ocp_t{ - fp8_impl::cvt_float_to_fp8( - x)}; -} - -// convert fp32 to bf8 with stochastic rounding -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) -{ - return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert _Float16 to fp8 with stochastic rounding -template <> -inline __host__ __device__ f8_ocp_t f8_convert_sr(_Float16 x) -{ - return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; -} - -// convert _Float16 to bf8 with stochastic rounding -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_sr(_Float16 x) -{ - return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; -} - #if CK_USE_OCP_FP8 using f8_t = f8_ocp_t; using bf8_t = bf8_ocp_t; diff --git a/include/ck/utility/mxf8_utils.hpp b/include/ck/utility/mxf8_utils.hpp index b7b98c6455..9046a24a3a 100644 --- a/include/ck/utility/mxf8_utils.hpp +++ b/include/ck/utility/mxf8_utils.hpp @@ -39,7 +39,7 @@ static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v) } template -static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v) +static __device__ float2_t cast_to_f32_from_f8_scaled(float scale, fp8x2_storage_t v) { const auto i16val = bit_cast(v); diff --git a/include/ck/utility/scaled_type_convert.hpp 
b/include/ck/utility/scaled_type_convert.hpp index 9a9c53caec..f3e2bd3dd9 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -67,7 +67,7 @@ inline __host__ float2_t scaled_type_convert(e8m0_bexp_t s #endif { #if CK_MX_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + return fp8_impl::cast_to_f32_from_f8_scaled( type_convert(scale), x.AsType()[Number<0>{}]); #else return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), @@ -86,7 +86,7 @@ inline __host__ float2_t scaled_type_convert(e8m0_bexp_t #endif { #if CK_MX_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + return fp8_impl::cast_to_f32_from_f8_scaled( type_convert(scale), x.AsType()[Number<0>{}]); #else return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index b9aeb44999..c8127aa887 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -117,7 +117,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float #if CK_USE_RNE_BF16_CONVERSION return bf16_convert_rtn(x); #else - return uint16_t(u.int32 >> 16); + return uint16_t(uint32_t{x} >> 16); #endif } @@ -356,6 +356,180 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(half_t x #endif } +/** + * @brief Converts a float to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 f8_ocp_t. 
+ */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(float2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +/** + * @brief Converts a float to a 8-bit float type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) +{ + return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(float2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(half_t x) +{ + return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(half2_t x) +{ + return f8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(half_t x) +{ + return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 half_t. 
+ * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(half2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(bhalf_t x) +{ + return f8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(bhalf2_t x) +{ + return f8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(bhalf_t x) +{ + return bf8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(bhalf2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + // Declare a template function for fp8 conversion using RNE template __host__ __device__ constexpr Y f8_convert_rne(X x); @@ -466,6 +640,172 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne(half_t #endif } +/** + * @brief Converts a float to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using rounding + * to nearest/even. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(float2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a float to a 8-bit float type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(float2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(half_t x) +{ + return f8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using rounding + * to nearest/even. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 f8_ocp_t. 
+ */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(half2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(half_t x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(half2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(bhalf_t x) +{ + return f8_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(bhalf2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(bhalf_t x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(bhalf2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8( + x)}; +} + // convert fp32 to fp8 template <> inline __host__ __device__ f8_fnuz_t type_convert(float x) @@ -477,17 +817,6 @@ inline __host__ __device__ f8_fnuz_t type_convert(float x) #endif } -// convert fp32 to fp8 -template <> -inline __host__ __device__ f8_ocp_t type_convert(float x) -{ -#if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); -#else - return f8_convert_rne(x); -#endif -} - // convert fp8 to fp32 template <> inline __host__ __device__ float type_convert(f8_fnuz_t x) @@ -524,12 +853,39 @@ inline __host__ __device__ float2_t type_convert(f8x2_fnu #endif } +/** + * @brief Converts a f8_ocp_t value to a float value. + * + * @param x The input f8_ocp_t value. + * @return The converted float value. + */ +template <> +inline __host__ __device__ float type_convert(f8_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + union + { + unsigned int i32val; + fp8_storage_t i8val[4]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0); +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 float values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 float values. 
+ */ template <> inline __host__ __device__ float2_t type_convert(f8x2_ocp_t x) { #if CK_OCP_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2( - x.AsType()[Number<0>{}]); + return __builtin_amdgcn_cvt_pk_f32_fp8(bit_cast(x), false); #else return float2_t{fp8_impl::cast_from_f8( x.AsType()[Number<0>{}]), @@ -538,6 +894,229 @@ inline __host__ __device__ float2_t type_convert(f8x2_ocp_ #endif } +/** + * @brief Converts a f8_ocp_t value to a half_t value. + * + * @param x The input f8_ocp_t value. + * @return The converted half_t value. + */ +template <> +inline __host__ __device__ half_t type_convert(f8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + half2_t half_vec; + half_t half_arr[2]; + } output; + output.half_vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.i16val, /*scale*/ 1.f, 0); + + return output.half_arr[0]; +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 half_t values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 half_t values. + */ +template <> +inline __host__ __device__ half2_t type_convert(f8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return half2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a f8_ocp_t value to a bhalf_t value. + * + * @param x The input f8_ocp_t value. + * @return The converted bhalf_t value. 
+ */ +template <> +inline __host__ __device__ bhalf_t type_convert(f8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + bhalf2_t bhalf_vec; + bhalf_t bhalf_arr[2]; + } output; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.i16val, /*scale*/ 1.f, 0); + + return output.bhalf_arr[0]; +#else + return type_convert( + fp8_impl::cast_from_f8(x.data)); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 bhalf_t values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 bhalf_t values. + */ +template <> +inline __host__ __device__ bhalf2_t type_convert(f8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return bhalf2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a float value. + * + * @param x The input bf8_ocp_t value. + * @return The converted float value. + */ +template <> +inline __host__ __device__ float type_convert(bf8_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + union + { + unsigned int i32val; + fp8_storage_t i8val[4]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 float values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 float values. 
+ */ +template <> +inline __host__ __device__ float2_t type_convert(bf8x2_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + return __builtin_amdgcn_cvt_pk_f32_bf8(bit_cast(x), false); +#else + return float2_t{fp8_impl::cast_from_f8( + x.AsType()[Number<0>{}]), + fp8_impl::cast_from_f8( + x.AsType()[Number<1>{}])}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a half_t value. + * + * @param x The input bf8_ocp_t value. + * @return The converted half_t value. + */ +template <> +inline __host__ __device__ half_t type_convert(bf8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(val.i16val, /*scale*/ 1.f, 0)[0]; +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 half_t values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 half_t values. + */ +template <> +inline __host__ __device__ half2_t type_convert(bf8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return half2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a bhalf_t value. + * + * @param x The input bf8_ocp_t value. + * @return The converted bhalf_t value. 
+ */ +template <> +inline __host__ __device__ bhalf_t type_convert(bf8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + bhalf2_t bhalf_vec; + bhalf_t bhalf_arr[2]; + } output; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0); + + return output.bhalf_arr[0]; +#else + return type_convert( + fp8_impl::cast_from_f8(x.data)); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 bhalf_t values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 bhalf_t values. + */ +template <> +inline __host__ __device__ bhalf2_t type_convert(bf8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return bhalf2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + template <> inline __host__ __device__ float2_t type_convert(pk_i4_t x) { @@ -610,7 +1189,12 @@ inline __host__ __device__ f8_fnuz_t type_convert(half_t x) #endif } -// convert fp16 to fp8 +/** + * @brief Converts a half_t value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ template <> inline __host__ __device__ f8_ocp_t type_convert(half_t x) { @@ -621,6 +1205,22 @@ inline __host__ __device__ f8_ocp_t type_convert(half_t x) #endif } +/** + * @brief Converts a half_t value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ bf8_ocp_t type_convert(half_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + // convert fp8 to fp16 template <> inline __host__ __device__ half_t type_convert(f8_fnuz_t x) @@ -645,7 +1245,28 @@ inline __host__ __device__ bf8_fnuz_t type_convert(float x) #endif } -// convert fp32 to bf8 +/** + * @brief Converts a float value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t type_convert(float x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + +/** + * @brief Converts a float value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ template <> inline __host__ __device__ bf8_ocp_t type_convert(float x) { @@ -656,6 +1277,38 @@ inline __host__ __device__ bf8_ocp_t type_convert(float x) #endif } +/** + * @brief Converts a bhalf_t value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t type_convert(bhalf_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + +/** + * @brief Converts a bhalf_t value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ bf8_ocp_t type_convert(bhalf_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + // convert bf8 to fp32 template <> inline __host__ __device__ float type_convert(bf8_fnuz_t x) @@ -683,17 +1336,6 @@ inline __host__ __device__ bf8_fnuz_t type_convert(half_t x) #endif } -// convert fp16 to bf8 -template <> -inline __host__ __device__ bf8_ocp_t type_convert(half_t x) -{ -#if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); -#else - return f8_convert_rne(x); -#endif -} - // convert bf8 to fp16 template <> inline __host__ __device__ half_t type_convert(bf8_fnuz_t x) diff --git a/test/data_type/test_bf8_ocp.cpp b/test/data_type/test_bf8_ocp.cpp index 9d4ee38b15..285e7e69fc 100644 --- a/test/data_type/test_bf8_ocp.cpp +++ b/test/data_type/test_bf8_ocp.cpp @@ -1,13 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" using ck::bf8_ocp_t; +using ck::bf8x2_ocp_t; +using ck::bhalf2_t; +using ck::bhalf_t; using ck::f8_convert_rne; using ck::f8_convert_sr; +using ck::float2_t; +using ck::half2_t; using ck::half_t; using ck::type_convert; @@ -266,3 +272,590 @@ TEST(BF8OCP, ConvertFP16Stochastic) const auto bf8_nan = f8_convert_sr(ck::NumericLimits::QuietNaN()); ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data)); } + +constexpr uint64_t test_size = 256 + 6; + +__host__ __device__ void +test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> fp32x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + float2_t f32x2 = type_convert(bf8x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // fp32x2 -> bf8x2 + f32x2 = {-4.0f, 2.0f}; + bf8x2 = f8_convert_rne(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostFP32BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + 
test_fp32_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(bf8_ocp_t{bf8_uid}); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -14.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -16.0f)); + + // fp32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_fp32_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceFP32BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp32_bf8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + 
bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(bf8_ocp_t{bf8_uid}); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -14.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -16.0f)); + + // fp32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> fp16x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + half2_t f16x2 = type_convert(bf8x2); + p_test[i++] = f16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f16x2[1]; + if(i >= N) + { + return; + } + + // fp16x2 -> bf8x2 + f16x2 = {-4.0f, 2.0f}; + bf8x2 = f8_convert_rne(f16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(f16x2); // expect {-4, 2} + + 
p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostFP16BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp16_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // fp16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + test_fp16_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceFP16BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(half_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp16_bf8_type_convert<<<1, 
1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // fp16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> bf16x2 + bf8x2_ocp_t 
bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + bhalf2_t bf16x2 = type_convert(bf8x2); + p_test[i++] = bf16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = bf16x2[1]; + if(i >= N) + { + return; + } + + // bf16x2 -> bf8x2 + bf16x2 = {type_convert(-4.0f), type_convert(2.0f)}; + bf8x2 = f8_convert_rne(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostBF16BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_bf16_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // bf16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // 
SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void +device_test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + test_bf16_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceBF16BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(bhalf_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_bf16_bf8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // bf16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], 
type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} diff --git a/test/data_type/test_fp8_ocp.cpp b/test/data_type/test_fp8_ocp.cpp index 944dd89930..bf562112c8 100644 --- a/test/data_type/test_fp8_ocp.cpp +++ b/test/data_type/test_fp8_ocp.cpp @@ -1,13 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +using ck::bhalf2_t; +using ck::bhalf_t; using ck::f8_convert_rne; using ck::f8_convert_sr; using ck::f8_ocp_t; +using ck::f8x2_ocp_t; +using ck::float2_t; +using ck::half2_t; using ck::half_t; using ck::type_convert; @@ -248,3 +254,566 @@ TEST(FP8OCP, ConvertFP16Stochastic) auto f8_nan = f8_convert_sr(ck::NumericLimits::QuietNaN()); ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data)); } + +constexpr uint64_t test_size = 256 + 6; + +__host__ __device__ void +test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> fp32x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9 + + float2_t f32x2 = type_convert(fp8x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // fp32x2 -> fp8x2 + f32x2 = {-4.0f, 2.0f}; + fp8x2 = f8_convert_rne(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } 
+ p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + fp8x2 = f8_convert_sr(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(FP8OCP, HostFP32FP8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp32_fp8_type_convert(test_size, out.data(), &completed); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(f8_ocp_t{fp8_uid}); + } + + // /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -6.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -9.0f)); + + // fp32x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_fp32_fp8_type_convert(N, p_test, p_completed); +} + +TEST(FP8OCP, DeviceFP32FP8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp32_fp8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + 
static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(f8_ocp_t{fp8_uid}); + } + + /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -6.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -9.0f)); + + // fp32x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> fp16x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9 + + half2_t f16x2 = type_convert(fp8x2); + p_test[i++] = f16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f16x2[1]; + if(i >= N) + { + return; + } + + // fp16x2 -> fp8x2 + f16x2 = {-4.0f, 2.0f}; + fp8x2 = f8_convert_rne(f16x2); // expect {-4, 2} + + p_test[i++] = 
type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f
+    if(i >= N)
+    {
+        return;
+    }
+
+    fp8x2 = f8_convert_sr(f16x2); // expect {-4, 2}
+
+    p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f
+    if(i >= N)
+    {
+        return;
+    }
+}
+
+/// Host-side check of the FP16 <-> FP8 (OCP) conversions.
+///
+/// Calls test_fp16_fp8_type_convert directly, then verifies the `out` buffer it filled:
+/// the two NaN bit patterns, the remaining scalar bit patterns, and six vector-conversion
+/// results stored from index 256 onwards.
+TEST(FP8OCP, HostFP16FP8Convert)
+{
+    std::vector out(test_size, -1.0f);
+    uint64_t completed = 0;
+
+    test_fp16_fp8_type_convert(test_size, out.data(), &completed);
+
+    // Entries indexed by the two NaN bit patterns must read back as NaN.
+    std::set fp8_nan_ids;
+    fp8_nan_ids.insert(0b11111111); //-NaN
+    fp8_nan_ids.insert(0b01111111); // +NaN
+    for(auto fp8_nan_id : fp8_nan_ids)
+    {
+        auto idx = fp8_nan_id;
+        ASSERT_TRUE(std::isnan(type_convert(out[idx])));
+    }
+
+    // Every other bit pattern is compared against a direct type_convert of the same
+    // f8_ocp_t value; out is indexed by the bit pattern itself.
+    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+    {
+        if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
+            continue;
+
+        uint8_t fp8_uid = static_cast(fp8_id);
+        auto idx        = fp8_uid;
+        ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid}))
+            << " fp8_id: " << fp8_id << std::endl
+            << type_convert(type_convert(f8_ocp_t{fp8_uid}));
+    }
+
+    /// Test vector conversions
+
+    // Vector results are read starting right after the 256 scalar entries.
+    auto i = 256;
+
+    // fp8x2 -> fp16x2
+    EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f)));
+    EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f)));
+
+    // fp16x2 -> fp8x2
+    // RNE
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+    // SR
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+
+    // Both the helper's counter and the number of entries checked here must equal test_size.
+    EXPECT_EQ(test_size, completed);
+    EXPECT_EQ(test_size, i);
+}
+
+/// Kernel entry point: forwards its arguments unchanged to test_fp16_fp8_type_convert.
+__global__ void device_test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
+{
+    test_fp16_fp8_type_convert(N, p_test, p_completed);
+}
+
+/// Device-side check of the FP16 <-> FP8 (OCP) conversions.
+///
+/// Launches device_test_fp16_fp8_type_convert with a <<<1, 1>>> configuration, reads the
+/// counter and the result buffer back with FromDevice, and then applies the same checks as
+/// the host test.
+TEST(FP8OCP, DeviceFP16FP8Convert)
+{
+    std::vector out(test_size, -1.0f);
+
+    // One buffer for test_size half_t results, one for the single uint64_t counter.
+    DeviceMem device_out(test_size * sizeof(half_t));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    // Pre-fill both buffers before the launch (presumably a sentinel so that untouched
+    // entries stand out -- confirm against DeviceMem::SetValue).
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    device_test_fp16_fp8_type_convert<<<1, 1>>>(
+        test_size,
+        static_cast(device_out.GetDeviceBuffer()),
+        static_cast(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    // Entries indexed by the two NaN bit patterns must read back as NaN.
+    std::set fp8_nan_ids;
+    fp8_nan_ids.insert(0b11111111); //-NaN
+    fp8_nan_ids.insert(0b01111111); // +NaN
+    for(auto fp8_nan_id : fp8_nan_ids)
+    {
+        auto idx = fp8_nan_id;
+        ASSERT_TRUE(std::isnan(type_convert(out[idx])))
+            << "idx: " << idx << " out[idx]: " << type_convert(out[idx]);
+    }
+
+    // Every other bit pattern is compared against a host-side type_convert of the same
+    // f8_ocp_t value.
+    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+    {
+        if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
+            continue;
+
+        uint8_t fp8_uid = static_cast(fp8_id);
+        auto idx        = fp8_uid;
+        ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid}))
+            << " fp8_id: " << fp8_id << std::endl
+            << type_convert(type_convert(f8_ocp_t{fp8_uid}));
+    }
+
+    /// Test vector conversions
+
+    // Vector results are read starting right after the 256 scalar entries.
+    auto i = 256;
+
+    // fp8x2 -> fp16x2
+    EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f)));
+    EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f)));
+
+    // fp16x2 -> fp8x2
+    // RNE
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+    // SR
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+
+    // Both the kernel's counter and the number of entries checked here must equal test_size.
+    EXPECT_EQ(test_size, completed);
+    EXPECT_EQ(test_size, i);
+}
+
+/// Fills p_test with the conversion results used by the BF16 <-> FP8 (OCP) tests.
+///
+/// @param N            Bound on the number of stored entries: the function returns as soon
+///                     as the running count reaches N.
+/// @param p_test       Output buffer; nothing is written when it is null.
+/// @param p_completed  Receives the number of entries stored. The function does nothing
+///                     when it is null; otherwise it is reset to 0 on entry and bumped
+///                     after every store, so the count is also available on early returns.
+///
+/// Layout written: entries [0, 256) hold type_convert(f8_ocp_t{bit pattern}) at
+/// index == bit pattern, followed by two entries from the fp8x2 -> bf16x2 vector
+/// conversion and four entries from converting {-4, 2} through f8_convert_rne and
+/// f8_convert_sr.
+///
+/// NOTE(review): each bound check runs after the store, so N == 0 would still write
+/// p_test[0] -- callers presumably always pass N >= 1; confirm.
+__host__ __device__ void
+test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
+{
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    // The caller's counter doubles as the write index.
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    // Scalar conversions: one entry per possible 8-bit pattern.
+    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+    {
+        uint8_t fp8_uid = static_cast(fp8_id);
+        auto v          = type_convert(f8_ocp_t{fp8_uid});
+        p_test[i]       = v;
+        i++;
+        if(i >= N)
+        {
+            return;
+        }
+    }
+
+    /// Test vector conversion
+    // fp8x2 -> bf16x2
+    f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
+
+    bhalf2_t bf16x2 = type_convert(fp8x2);
+    p_test[i++]     = bf16x2[0];
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = bf16x2[1];
+    if(i >= N)
+    {
+        return;
+    }
+
+    // bf16x2 -> fp8x2
+    bf16x2 = {type_convert(-4.0f), type_convert(2.0f)};
+    fp8x2  = f8_convert_rne(bf16x2); // expect {-4, 2}
+
+    p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f
+    if(i >= N)
+    {
+        return;
+    }
+
+    fp8x2 = f8_convert_sr(bf16x2); // expect {-4, 2}
+
+    p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f
+    if(i >= N)
+    {
+        return;
+    }
+}
+
+/// Host-side check of the BF16 <-> FP8 (OCP) conversions.
+///
+/// Calls test_bf16_fp8_type_convert directly, then verifies the `out` buffer it filled:
+/// the two NaN bit patterns, the remaining scalar bit patterns, and six vector-conversion
+/// results stored from index 256 onwards.
+TEST(FP8OCP, HostBF16FP8Convert)
+{
+    std::vector out(test_size, -1.0f);
+    uint64_t completed = 0;
+
+    test_bf16_fp8_type_convert(test_size, out.data(), &completed);
+
+    // Entries indexed by the two NaN bit patterns must read back as NaN.
+    std::set fp8_nan_ids;
+    fp8_nan_ids.insert(0b11111111); //-NaN
+    fp8_nan_ids.insert(0b01111111); // +NaN
+    for(auto fp8_nan_id : fp8_nan_ids)
+    {
+        auto idx = fp8_nan_id;
+        ASSERT_TRUE(std::isnan(type_convert(out[idx])));
+    }
+
+    // Every other bit pattern is compared against a direct type_convert of the same
+    // f8_ocp_t value; out is indexed by the bit pattern itself.
+    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+    {
+        if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
+            continue;
+
+        uint8_t fp8_uid = static_cast(fp8_id);
+        auto idx        = fp8_uid;
+        ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid}))
+            << " fp8_id: " << fp8_id << std::endl
+            << type_convert(type_convert(f8_ocp_t{fp8_uid}));
+    }
+
+    /// Test vector conversions
+
+    // Vector results are read starting right after the 256 scalar entries.
+    auto i = 256;
+
+    // fp8x2 -> bf16x2
+    EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f)));
+    EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f)));
+
+    // bf16x2 -> fp8x2
+    // RNE
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+    // SR
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+
+    // Both the helper's counter and the number of entries checked here must equal test_size.
+    EXPECT_EQ(test_size, completed);
+    EXPECT_EQ(test_size, i);
+}
+
+/// Kernel entry point: forwards its arguments unchanged to test_bf16_fp8_type_convert.
+__global__ void
+device_test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
+{
+    test_bf16_fp8_type_convert(N, p_test, p_completed);
+}
+
+/// Device-side check of the BF16 <-> FP8 (OCP) conversions.
+///
+/// Launches device_test_bf16_fp8_type_convert with a <<<1, 1>>> configuration, reads the
+/// counter and the result buffer back with FromDevice, and then verifies the two NaN bit
+/// patterns, the remaining scalar bit patterns, and six vector-conversion results read from
+/// index 256 onwards.
+TEST(FP8OCP, DeviceBF16FP8Convert)
+{
+    std::vector out(test_size, -1.0f);
+
+    // One buffer for test_size bhalf_t results, one for the single uint64_t counter.
+    DeviceMem device_out(test_size * sizeof(bhalf_t));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    // Pre-fill both buffers before the launch (presumably a sentinel so that untouched
+    // entries stand out -- confirm against DeviceMem::SetValue).
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    device_test_bf16_fp8_type_convert<<<1, 1>>>(
+        test_size,
+        static_cast(device_out.GetDeviceBuffer()),
+        static_cast(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    // Entries indexed by the two NaN bit patterns must read back as NaN.
+    std::set fp8_nan_ids;
+    fp8_nan_ids.insert(0b11111111); //-NaN
+    fp8_nan_ids.insert(0b01111111); // +NaN
+    for(auto fp8_nan_id : fp8_nan_ids)
+    {
+        auto idx = fp8_nan_id;
+        ASSERT_TRUE(std::isnan(type_convert(out[idx])))
+            << "idx: " << idx << " out[idx]: " << type_convert(out[idx]);
+    }
+
+    // Every other bit pattern is compared against a host-side type_convert of the same
+    // f8_ocp_t value; out is indexed by the bit pattern itself.
+    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+    {
+        if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
+            continue;
+
+        uint8_t fp8_uid = static_cast(fp8_id);
+        auto idx        = fp8_uid;
+        ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid}))
+            << " fp8_id: " << fp8_id << std::endl
+            << type_convert(type_convert(f8_ocp_t{fp8_uid}));
+    }
+
+    /// Test vector conversions
+
+    // Vector results are read starting right after the 256 scalar entries.
+    auto i = 256;
+
+    // fp8x2 -> bf16x2
+    EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f)));
+    EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f)));
+
+    // bf16x2 -> fp8x2
+    // RNE
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+    // SR
+    EXPECT_EQ(out[i++], type_convert(-4.0f));
+    EXPECT_EQ(out[i++], type_convert(2.0f));
+
+    // Both the kernel's counter and the number of entries checked here must equal test_size.
+    EXPECT_EQ(test_size, completed);
+    EXPECT_EQ(test_size, i);
+}