From 4e27eae99d11caab824feecb7e94149779139a96 Mon Sep 17 00:00:00 2001
From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Date: Mon, 27 Nov 2023 20:06:17 -0600
Subject: [PATCH] Switch default f8 conversion to stochastic rounding (#1048)

* Switch default f8 conversion to stochastic rounding

* Refactor f8-related type_converts

* Add an element-wise op

[ROCm/composable_kernel commit: 6ef034f6cad7ec70b3a06518bec7fef8def11d51]
---
 include/ck/ck.hpp                                  |   3 +
 .../element/unary_element_wise_operation.hpp       |  18 +
 include/ck/utility/type_convert.hpp                | 524 ++++++++++--------
 3 files changed, 307 insertions(+), 238 deletions(-)

diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp
index 1e41404192..4a2b5c0ad7 100644
--- a/include/ck/ck.hpp
+++ b/include/ck/ck.hpp
@@ -134,6 +134,9 @@
 // inner product using V_DOT with DPP8 modifiers
 #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
 
+// set stochastic rounding as default for f8 conversions
+#define CK_USE_SR_F8_CONVERSION 1
+
 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
 
diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
index e72b122cfc..e9c85964c5 100644
--- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
+++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -281,6 +281,24 @@ struct ConvertF8SR
     }
 };
 
+struct ConvertF8RNE
+{
+    // convert to fp8 using rounding to nearest even
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
+    {
+        // check Y datatype
+        static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
+                      "Data type is not supported by this operation!");
+
+        // check X datatype
+        static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = f8_convert_rne<Y>(x);
+    }
+};
+
 struct Scale
 {
     __host__ __device__ Scale(float scale) : scale_(scale) {}
diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp
index 1d754fabe4..70bc6f278c 100644
--- a/include/ck/utility/type_convert.hpp
+++ b/include/ck/utility/type_convert.hpp
@@ -95,243 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
     return type_convert<bhalf_t>(x_fp32);
 }
 
-// convert fp32 to fp8
-template <>
-inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    float max_fp8 = 240.0f;
-    x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
-    union
-    {
-        float fval;
-        uint32_t i32val;
-        uint8_t i8val[4]; // not endian independent
-    } val;
-    val.fval = x;
-    uint32_t ival = 0;
-    ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
-    val.i32val = ival;
-    return val.i8val[0];
-#else
-    constexpr bool negative_zero_nan = true;
-    constexpr bool clip = true;
-    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
-    constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
-#endif
-}
-
-// convert fp8 to fp32
-template <>
-inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    float fval;
-    uint32_t i32val = static_cast<uint32_t>(x);
-    fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
-    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
-    return fval;
-#else
-    constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
-#endif
-}
-
-template <>
-inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    const auto i16val = bit_cast<uint16_t>(x);
-    return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
-#else
-    constexpr bool negative_zero_nan = true;
-    const auto f8x2_v = vector_type<f8_t, 2>(x);
-    vector_type<float, 2> f32x2_v;
-    f32x2_v.template AsType<float>()(Number<0>{}) =
-        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
-            f8x2_v.template AsType<f8_t>()[Number<0>{}]);
-    f32x2_v.template AsType<float>()(Number<1>{}) =
-        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
-            f8x2_v.template AsType<f8_t>()[Number<1>{}]);
-    return f32x2_v.template AsType<float2_t>()[Number<0>{}];
-#endif
-}
-
-template <>
-inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
-{
-
-    const vector_type<float, 2> f32x2_v(x);
-    const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
-                                              f32x2_v.template AsType<float>()[Number<1>{}]);
-    return bit_cast<half2_t>(y);
-}
-
-// convert fp16 to fp8
-template <>
-inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    // convert to float and use native conversion
-    return type_convert<f8_t>(type_convert<float>(x));
-#else
-    constexpr bool negative_zero_nan = true;
-    constexpr bool clip = true;
-    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
-    constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
-#endif
-}
-
-// convert fp8 to fp16
-template <>
-inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    // use native conversion to float and convert to fp16
-    return type_convert<half_t>(type_convert<float>(x));
-#else
-    constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
-#endif
-}
-
-// convert fp32 to bf8
-template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    union
-    {
-        float fval;
-        uint32_t i32val;
-        uint8_t i8val[4]; // not endian independent
-    } val;
-    val.fval = x;
-    uint32_t ival = 0;
-    ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
-    val.i32val = ival;
-    return val.i8val[0];
-#else
-    constexpr bool negative_zero_nan = true;
-    constexpr bool clip = true;
-    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
-    constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
-#endif
-}
-
-// convert bf8 to fp32
-template <>
-inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    float fval;
-    uint32_t i32val = static_cast<uint32_t>(x);
-    fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
-    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
-    return fval;
-#else
-    constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
-#endif
-}
-
-// convert fp16 to bf8
-template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    // convert to float and use native conversion
-    return type_convert<bf8_t>(type_convert<float>(x));
-#else
-    constexpr bool negative_zero_nan = true;
-    constexpr bool clip = true;
-    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
-    constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
-#endif
-}
-
-// convert bf8 to fp16
-template <>
-inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
-{
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
-    // use native conversion to float and convert to fp16
-    return type_convert<half_t>(type_convert<float>(x));
-#else
-    constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
-#endif
-}
-
-// Declare a template function for bf16 conversion using RTN
-template <typename Y, typename X>
-__host__ __device__ constexpr Y bf16_convert_rtn(X x);
-
-// Convert fp32 to bf16 with RTN if higher precision is needed
-template <>
-inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
-{
-    union
-    {
-        float fp32;
-        uint32_t int32;
-    } u = {x};
-
-    // When the exponent bits are not all 1s, then the value is zero, normal,
-    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
-    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
-    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
-    // least significant bits of the float mantissa are greater than 0x8000,
-    // or if they are equal to 0x8000 and the least significant bit of the
-    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
-    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
-    // has the value 0x7f, then incrementing it causes it to become 0x00 and
-    // the exponent is incremented by one, which is the next higher FP value
-    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
-    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
-    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
-    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
-    // incrementing it causes it to become an exponent of 0xFF and a mantissa
-    // of 0x00, which is Inf, the next higher value to the unrounded value.
-    bool flag0 = ~u.int32 & 0x7f800000;
-
-    // When all of the exponent bits are 1, the value is Inf or NaN.
-    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
-    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
-    // bit being 1. Signaling NaN is indicated by the most significant
-    // mantissa bit being 0 but some other bit(s) being 1. If any of the
-    // lower 16 bits of the mantissa are 1, we set the least significant bit
-    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
-    // the bfloat16's mantissa bits are all 0.
-    bool flag1 = !flag0 && (u.int32 & 0xffff);
-
-    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
-    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN
-
-    return uint16_t(u.int32 >> 16);
-}
-
-// convert fp16 to bf16 via fp32 with RTN if higher precision is needed
-template <>
-inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
-{
-    float x_fp32 = static_cast<float>(x);
-
-    return bf16_convert_rtn<bhalf_t>(x_fp32);
-}
-
 // Declare a template function for fp8 conversion using SR
 template <typename Y, typename X>
 __host__ __device__ constexpr Y f8_convert_sr(X x);
@@ -343,6 +106,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
     constexpr int seed = 42;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    float max_fp8 = 240.0f;
+    x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union
     {
         float fval;
@@ -423,7 +188,6 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed = 42;
-    // as thread id is not available on host, use 0 for prn generation
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
     return utils::
         cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
             x, rng);
@@ -431,4 +195,288 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
 #endif
 }
 
+// Declare a template function for fp8 conversion using RNE
+template <typename Y, typename X>
+__host__ __device__ constexpr Y f8_convert_rne(X x);
+
+// convert fp32 to fp8 with rounding to nearest even
+template <>
+inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    float max_fp8 = 240.0f;
+    x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
+    union
+    {
+        float fval;
+        uint32_t i32val;
+        uint8_t i8val[4]; // not endian independent
+    } val;
+    val.fval = x;
+    uint32_t ival = 0;
+    ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
+    val.i32val = ival;
+    return val.i8val[0];
+#else
+    constexpr bool negative_zero_nan = true;
+    constexpr bool clip = true;
+    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
+    constexpr uint32_t rng = 0;
+    return utils::
+        cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
+            x, rng);
+#endif
+}
+
+// convert fp16 to fp8 with rounding to nearest even
+template <>
+inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // convert to float and use native conversion
+    return f8_convert_rne<f8_t>(type_convert<float>(x));
+#else
+    constexpr bool negative_zero_nan = true;
+    constexpr bool clip = true;
+    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
+    constexpr uint32_t rng = 0;
+    return utils::
+        cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
+            x, rng);
+#endif
+}
+
+// convert fp32 to bf8 with rounding to nearest even
+template <>
+inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    union
+    {
+        float fval;
+        uint32_t i32val;
+        uint8_t i8val[4]; // not endian independent
+    } val;
+    val.fval = x;
+    uint32_t ival = 0;
+    ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
+    val.i32val = ival;
+    return val.i8val[0];
+#else
+    constexpr bool negative_zero_nan = true;
+    constexpr bool clip = true;
+    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
+    constexpr uint32_t rng = 0;
+    return utils::
+        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
+            x, rng);
+#endif
+}
+
+// convert fp16 to bf8 with rounding to nearest even
+template <>
+inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // convert to float and use native conversion
+    return f8_convert_rne<bf8_t>(type_convert<float>(x));
+#else
+    constexpr bool negative_zero_nan = true;
+    constexpr bool clip = true;
+    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
+    constexpr uint32_t rng = 0;
+    return utils::
+        cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
+            x, rng);
+#endif
+}
+
+// convert fp32 to fp8
+template <>
+inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
+{
+#if defined CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_t>(x);
+#else
+    return f8_convert_rne<f8_t>(x);
+#endif
+}
+
+// convert fp8 to fp32
+template <>
+inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    float fval;
+    uint32_t i32val = static_cast<uint32_t>(x);
+    fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
+    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
+    return fval;
+#else
+    constexpr bool negative_zero_nan = true;
+    return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
+#endif
+}
+
+template <>
+inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    const auto i16val = bit_cast<uint16_t>(x);
+    return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
+#else
+    constexpr bool negative_zero_nan = true;
+    const auto f8x2_v = vector_type<f8_t, 2>(x);
+    vector_type<float, 2> f32x2_v;
+    f32x2_v.template AsType<float>()(Number<0>{}) =
+        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_t>()[Number<0>{}]);
+    f32x2_v.template AsType<float>()(Number<1>{}) =
+        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_t>()[Number<1>{}]);
+    return f32x2_v.template AsType<float2_t>()[Number<0>{}];
+#endif
+}
+
+template <>
+inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
+{
+
+    const vector_type<float, 2> f32x2_v(x);
+    const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
+                                              f32x2_v.template AsType<float>()[Number<1>{}]);
+    return bit_cast<half2_t>(y);
+}
+
+// convert fp16 to fp8
+template <>
+inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
+{
+#if defined CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_t>(x);
+#else
+    return f8_convert_rne<f8_t>(x);
+#endif
+}
+
+// convert fp8 to fp16
+template <>
+inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // use native conversion to float and convert to fp16
+    return type_convert<half_t>(type_convert<float>(x));
+#else
+    constexpr bool negative_zero_nan = true;
+    return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
+#endif
+}
+
+// convert fp32 to bf8
+template <>
+inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
+{
+#if defined CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_t>(x);
+#else
+    return f8_convert_rne<bf8_t>(x);
+#endif
+}
+
+// convert bf8 to fp32
+template <>
+inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    float fval;
+    uint32_t i32val = static_cast<uint32_t>(x);
+    fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
+    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
+    return fval;
+#else
+    constexpr bool negative_zero_nan = true;
+    return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
+#endif
+}
+
+// convert fp16 to bf8
+template <>
+inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
+{
+#if defined CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_t>(x);
+#else
+    return f8_convert_rne<bf8_t>(x);
+#endif
+}
+
+// convert bf8 to fp16
+template <>
+inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
+{
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+    // use native conversion to float and convert to fp16
+    return type_convert<half_t>(type_convert<float>(x));
+#else
+    constexpr bool negative_zero_nan = true;
+    return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
+#endif
+}
+
+// Declare a template function for bf16 conversion using RTN
+template <typename Y, typename X>
+__host__ __device__ constexpr Y bf16_convert_rtn(X x);
+
+// Convert fp32 to bf16 with RTN if higher precision is needed
+template <>
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {x};
+
+    // When the exponent bits are not all 1s, then the value is zero, normal,
+    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
+    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
+    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
+    // least significant bits of the float mantissa are greater than 0x8000,
+    // or if they are equal to 0x8000 and the least significant bit of the
+    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
+    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
+    // has the value 0x7f, then incrementing it causes it to become 0x00 and
+    // the exponent is incremented by one, which is the next higher FP value
+    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
+    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
+    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
+    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
+    // incrementing it causes it to become an exponent of 0xFF and a mantissa
+    // of 0x00, which is Inf, the next higher value to the unrounded value.
+    bool flag0 = ~u.int32 & 0x7f800000;
+
+    // When all of the exponent bits are 1, the value is Inf or NaN.
+    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
+    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
+    // bit being 1. Signaling NaN is indicated by the most significant
+    // mantissa bit being 0 but some other bit(s) being 1. If any of the
+    // lower 16 bits of the mantissa are 1, we set the least significant bit
+    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
+    // the bfloat16's mantissa bits are all 0.
+    bool flag1 = !flag0 && (u.int32 & 0xffff);
+
+    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
+    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN
+
+    return uint16_t(u.int32 >> 16);
+}
+
+// convert fp16 to bf16 via fp32 with RTN if higher precision is needed
+template <>
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
+{
+    float x_fp32 = static_cast<float>(x);
+
+    return bf16_convert_rtn<bhalf_t>(x_fp32);
+}
 } // namespace ck
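
Usage note: a minimal sketch (not part of the patch itself) of how the new default and the explicit rounding paths might be exercised, assuming compilation with hipcc against the CK headers touched above. The ck::tensor_operation::element_wise namespace for ConvertF8RNE follows the existing ConvertF8SR functor; the example function itself is illustrative only.

// Sketch only: assumes the CK headers from this patch are on the include path.
#include "ck/ck.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

void fp8_rounding_example(float x)
{
    using ck::f8_t;

    // Default conversion: with CK_USE_SR_F8_CONVERSION defined in ck.hpp (the new
    // default), type_convert<f8_t> goes through the stochastic-rounding path.
    f8_t y_default = ck::type_convert<f8_t>(x);

    // The rounding mode can still be chosen explicitly, independent of the macro.
    f8_t y_sr  = ck::f8_convert_sr<f8_t>(x);
    f8_t y_rne = ck::f8_convert_rne<f8_t>(x);

    // The new element-wise functor mirrors ConvertF8SR but rounds to nearest even.
    ck::tensor_operation::element_wise::ConvertF8RNE convert_rne{};
    f8_t y_op;
    convert_rne(y_op, x);

    (void)y_default;
    (void)y_sr;
    (void)y_rne;
    (void)y_op;
}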