diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp
index 5fbebb708d..217b339b66 100644
--- a/include/ck/utility/f8_utils.hpp
+++ b/include/ck/utility/f8_utils.hpp
@@ -5,6 +5,8 @@
 
 #include "ck/utility/data_type.hpp"
 
+// these conversions are disabled if native conversions available
+#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
 #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
 
 namespace ck {
@@ -242,4 +244,5 @@ __host__ __device__ Y cast_from_f8(X x)
 }
 
 } // namespace ck::utils
-#endif
+#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
+#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp
index 5c5447f94e..70619ee0a5 100644
--- a/include/ck/utility/type_convert.hpp
+++ b/include/ck/utility/type_convert.hpp
@@ -85,6 +85,19 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
 template <>
 inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    union
+    {
+        float fval;
+        uint32_t i32val;
+        uint8_t i8val[4]; // not endian independent
+    } val;
+    val.fval = x;
+    uint32_t ival = 0;
+    ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
+    val.i32val = ival;
+    return val.i8val[0];
+#else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
@@ -92,20 +105,33 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
     return utils::
         cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
                                                                                                rng);
+#endif
 }
 
 // convert fp8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    float fval;
+    uint32_t i32val = static_cast<uint32_t>(x);
+    fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
+    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
+    return fval;
+#else
     constexpr bool negative_zero_nan = true;
     return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
+#endif
 }
 
 // convert fp16 to fp8
 template <>
 inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // convert to float and use native conversion
+    return type_convert<f8_t>(type_convert<float>(x));
+#else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
@@ -113,14 +139,20 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
     return utils::
         cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
             x, rng);
+#endif
 }
 
 // convert fp8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // use native conversion to float and convert to fp16
+    return type_convert<half_t>(type_convert<float>(x));
+#else
     constexpr bool negative_zero_nan = true;
     return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
+#endif
 }
 
 #endif
@@ -129,6 +161,19 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
 template <>
 inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    union
+    {
+        float fval;
+        uint32_t i32val;
+        uint8_t i8val[4]; // not endian independent
+    } val;
+    val.fval = x;
+    uint32_t ival = 0;
+    ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
+    val.i32val = ival;
+    return val.i8val[0];
+#else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
@@ -136,20 +181,33 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
     return utils::
         cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
             x, rng);
+#endif
 }
 
 // convert bf8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    float fval;
+    uint32_t i32val = static_cast<uint32_t>(x);
+    fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
+    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
+    return fval;
+#else
     constexpr bool negative_zero_nan = true;
     return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
+#endif
 }
 
 // convert fp16 to bf8
 template <>
 inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // convert to float and use native conversion
+    return type_convert<bf8_t>(type_convert<float>(x));
+#else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
@@ -157,14 +215,20 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
     return utils::
         cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
             x, rng);
+#endif
 }
 
 // convert bf8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // use native conversion to float and convert to fp16
+    return type_convert<half_t>(type_convert<float>(x));
+#else
     constexpr bool negative_zero_nan = true;
     return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
+#endif
 }
 
 #endif
@@ -234,30 +298,47 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
 template <>
 inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 {
+    constexpr int seed = 42;
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    union
+    {
+        float fval;
+        uint32_t i32val;
+        uint8_t i8val[4]; // not endian independent
+    } val;
+    val.fval = x;
+    uint32_t ival = 0;
+    ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
+    val.i32val = ival;
+    return val.i8val[0]; // little endian
+#else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
-    constexpr int seed = 42;
-    // as thread id is not available on host, use 0 for prn generation
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
     return utils::
         cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
                                                                                                rng);
+#endif
 }
 
 // convert fp16 to fp8 with stochastic rounding
 template <>
 inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // convert to float and use native conversion
+    return f8_convert_sr<f8_t>(type_convert<float>(x));
+#else
     constexpr bool negative_zero_nan = true;
    constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed = 42;
-    // as thread id is not available on host, use 0 for prn generation
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
     return utils::
         cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
             x, rng);
+#endif
 }
 
 #endif
@@ -266,21 +347,38 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
 {
+    constexpr int seed = 42;
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    union
+    {
+        float fval;
+        uint32_t i32val;
+        uint8_t i8val[4]; // not endian independent
+    } val;
+    val.fval = x;
+    uint32_t ival = 0;
+    ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
+    val.i32val = ival;
+    return val.i8val[0]; // little endian
+#else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
-    constexpr int seed = 42;
-    // as thread id is not available on host, use 0 for prn generation
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
     return utils::
         cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
             x, rng);
+#endif
 }
 
 // convert fp16 to bf8 with stochastic rounding
 template <>
 inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
 {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // convert to float and use native conversion
+    return f8_convert_sr<bf8_t>(type_convert<float>(x));
+#else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
@@ -290,6 +388,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
     return utils::
         cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
             x, rng);
+#endif
 }
 
 #endif